diff --git a/torchbiggraph/train.py b/torchbiggraph/train.py index 2ed2b9c9..c920e69a 100644 --- a/torchbiggraph/train.py +++ b/torchbiggraph/train.py @@ -89,6 +89,7 @@ ) + logger = logging.getLogger("torchbiggraph") dist_logger = logging.LoggerAdapter(logger, {"distributed": True}) @@ -634,34 +635,43 @@ def train_and_report_stats( bucket_logger.debug("Shuffling edges") # Fix a seed to get the same permutation every time; have it # depend on all and only what affects the set of edges. - g = torch.Generator() - g.manual_seed( - hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)) - ) + # Note: for the sake of efficiency, we sample eval edge idxs + # from the edge set *with replacement*, meaning that there may + # be duplicates of the same edge in the eval set. When we swap + # edges into the eval set, if there are duplicates then all + # but one will be clobbered. These collisions are unlikely + # if eval_fraction is small. + # + # Importantly, this eval sampling strategy is theoretically + # sound: + # * Training and eval sets are (exactly) disjoint + # * Eval set may have (rare) duplicates, but they are + # uniformly sampled so it's still an unbiased estimator + # of the out-of-sample statistics num_eval_edges = int(num_edges * config.eval_fraction) if num_eval_edges > 0: - edge_perm = torch.randperm(num_edges, generator=g) - eval_edge_perm = edge_perm[-num_eval_edges:] - num_edges -= num_eval_edges - edge_perm = edge_perm[torch.randperm(num_edges)] + g = torch.Generator() + g.manual_seed( + hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)) + ) + eval_edge_idxs = torch.randint(num_edges, (num_eval_edges,), dtype=torch.long, generator=g) else: - eval_edge_perm = None - edge_perm = torch.randperm(num_edges) + eval_edge_idxs = None logger.info(f"Shuffled edges in {time.time() - tic} seconds") # HOGWILD evaluation before training - eval_stats_before = self._coordinate_eval(edges, eval_edge_perm) + eval_stats_before = self._coordinate_eval(edges, eval_edge_idxs) if eval_stats_before is not None: bucket_logger.info(f"Stats before training: {eval_stats_before}") # HOGWILD training bucket_logger.debug("Waiting for workers to perform training") - stats = self._coordinate_train(edges, edge_perm, epoch_idx) + stats = self._coordinate_train(edges, eval_edge_idxs, epoch_idx) # HOGWILD evaluation after training - eval_stats_after = self._coordinate_eval(edges, eval_edge_perm) + eval_stats_after = self._coordinate_eval(edges, eval_edge_idxs) if eval_stats_after is not None: bucket_logger.info(f"Stats before training: {eval_stats_after}") @@ -814,9 +824,18 @@ def _swap_partitioned_embeddings( return io_bytes - def _coordinate_train(self, edges, edge_perm, epoch_idx) -> Stats: + def _coordinate_train(self, edges, eval_edge_idxs, epoch_idx) -> Stats: assert self.config.num_gpus == 0, "GPU training not supported" + if eval_edge_idxs is not None: + num_train_edges = len(edges) - len(eval_edge_idxs) + train_edge_idxs = torch.arange(len(edges)) + train_edge_idxs[eval_edge_idxs] = torch.arange(num_train_edges, len(edges)) + train_edge_idxs = train_edge_idxs[:num_train_edges] + edge_perm = train_edge_idxs[torch.randperm(num_train_edges)] + else: + edge_perm = torch.randperm(len(edges)) + future_all_stats = self.pool.map_async(call, [ partial( process_in_batches, @@ -836,12 +855,12 @@ def _coordinate_train(self, edges, edge_perm, epoch_idx) -> Stats: all_stats = get_async_result(future_all_stats, self.pool) return Stats.sum(all_stats).average() - def _coordinate_eval(self, edges, eval_edge_perm) -> Optional[Stats]: + def _coordinate_eval(self, edges, eval_edge_idxs) -> Optional[Stats]: eval_batch_size = round_up_to_nearest_multiple( self.config.batch_size, self.config.eval_num_batch_negs ) - if eval_edge_perm is not None: + if eval_edge_idxs is not None: self.bucket_logger.debug("Waiting for workers to perform evaluation") future_all_eval_stats = self.pool.map_async(call, [ partial( @@ -850,9 +869,9 @@ def _coordinate_eval(self, edges, eval_edge_perm) -> Optional[Stats]: model=self.model, batch_processor=self.evaluator, edges=edges, - indices=eval_edge_perm[s], + indices=eval_edge_idxs[s], ) - for s in split_almost_equally(eval_edge_perm.size(0), + for s in split_almost_equally(eval_edge_idxs.size(0), num_parts=self.num_workers) ]) all_eval_stats = \