Skip to content

Commit

Permalink
Don't try to checkpoint unpartitioned entities from the partition server
Browse files Browse the repository at this point in the history
Reviewed By: lw

Differential Revision: D20758139

fbshipit-source-id: ac5493e2968f8725d8dbbf32475b8dbe0e90e3b8
  • Loading branch information
adamlerer authored and facebook-github-bot committed Apr 1, 2020
1 parent 5b848ab commit 896dffc
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
10 changes: 8 additions & 2 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def test_partitioned(self):
self.assertCheckpointWritten(train_config, version=1)
do_eval(eval_config, subprocess_init=self.subprocess_init)

def test_distributed(self):
def _test_distributed(self, num_partitions):
sync_path = TemporaryDirectory()
self.addCleanup(sync_path.cleanup)
entity_name = "e"
Expand All @@ -495,7 +495,7 @@ def test_distributed(self):
base_config = ConfigSchema(
dimension=10,
relations=[relation_config],
entities={entity_name: EntitySchema(num_partitions=4)},
entities={entity_name: EntitySchema(num_partitions=num_partitions)},
entity_path=None, # filled in later
edge_paths=[], # filled in later
checkpoint_path=self.checkpoint_path.name,
Expand Down Expand Up @@ -547,6 +547,12 @@ def test_distributed(self):
done[1] = True
self.assertCheckpointWritten(train_config, version=1)

def test_distributed(self):
    # Distributed-training smoke test with a partitioned entity
    # (num_partitions=4); the shared scenario lives in _test_distributed.
    self._test_distributed(num_partitions=4)

def test_distributed_unpartitioned(self):
    # Same distributed scenario but with a single partition (an
    # "unpartitioned" entity); regression test for checkpointing
    # unpartitioned entities separately from the partition server
    # (see this commit's changes to checkpoint_manager.py / train.py).
    self._test_distributed(num_partitions=1)

def test_distributed_with_partition_servers(self):
sync_path = TemporaryDirectory()
self.addCleanup(sync_path.cleanup)
Expand Down
9 changes: 8 additions & 1 deletion torchbiggraph/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def write(
embs: FloatTensorType,
optim_state: Optional[OptimizerStateDict],
force_clean: bool = False,
unpartitioned: bool = False,
) -> None:
if not force_clean:
self.dirty.add((entity, part))
Expand All @@ -271,7 +272,7 @@ def write(
metadata = self.collect_metadata()
serialized_optim_state = serialize_optim_state(optim_state)

if self.partition_client is not None:
if self.partition_client is not None and not unpartitioned:
self.partition_client.store(entity, part, embs, serialized_optim_state)
else:
self.storage.save_entity_partition(
Expand Down Expand Up @@ -381,6 +382,12 @@ def write_new_version(
if self.partition_client is not None:
for entity, econf in config.entities.items():
dimension = config.entity_dimension(entity)

if econf.num_partitions == 1:
# unpartitioned entities are not stored on the partition
# server; they are checkpointed separately in train.py
continue

for part in range(self.rank, econf.num_partitions, self.num_machines):
logger.debug(f"Getting {entity} {part}")
count = entity_counts[entity][part]
Expand Down
7 changes: 6 additions & 1 deletion torchbiggraph/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,10 +927,15 @@ def _maybe_write_checkpoint(
if self.rank == 0:
for entity, econfig in config.entities.items():
if econfig.num_partitions == 1:
logger.info(f"Writing {entity} embeddings")
embs = self.holder.unpartitioned_embeddings[entity]
optimizer = self.trainer.unpartitioned_optimizers[entity]
self.checkpoint_manager.write(
entity, UNPARTITIONED, embs.detach(), optimizer.state_dict()
entity,
UNPARTITIONED,
embs.detach(),
optimizer.state_dict(),
unpartitioned=True,
)

logger.info("Writing the metadata")
Expand Down

0 comments on commit 896dffc

Please sign in to comment.