Allow for single-partition training on 1 GPU machine (simple version)
Summary: GPU training was written not to support "unpartitioned embeddings", but that precludes an important case: a single partition on a single machine. In that case, none of the special handling is needed, neither the swapping done for partitioned embeddings nor the sharing done for unpartitioned ones. This case matters for e.g. testing with small graphs.

Reviewed By: lw

Differential Revision: D21302580

fbshipit-source-id: 33fccdcc80ff06398e4a07a2fd6d2a654830045e
adamlerer authored and facebook-github-bot committed May 4, 2020
1 parent 9ae8592 commit b7716f3
Showing 4 changed files with 47 additions and 65 deletions.
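
For context, a minimal sketch of the single-partition, single-machine setup this commit enables. The entity and relation names, paths, and hyperparameters below are illustrative (not taken from the commit), and the remaining ConfigSchema fields are assumed to keep their defaults, in particular num_machines=1; GPU-specific options are omitted.

from torchbiggraph.config import ConfigSchema, EntitySchema, RelationSchema

# One entity type, one relation, one partition: no swapping of partitions
# between buckets and no sharing of embeddings across machines is needed.
config = ConfigSchema(
    dimension=10,
    relations=[RelationSchema(name="r", lhs="e", rhs="e")],
    entities={"e": EntitySchema(num_partitions=1)},
    entity_path="data/example_graph",         # hypothetical paths
    edge_paths=["data/example_graph/edges"],
    checkpoint_path="model/example_graph",
)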
20 changes: 14 additions & 6 deletions test/test_functional.py
@@ -498,7 +498,11 @@ def test_gpu(self):
def test_gpu_half(self):
self._test_gpu(do_half_precision=True)

def _test_gpu(self, do_half_precision=False):
@unittest.skipIf(not torch.cuda.is_available(), "No GPU")
def test_gpu_1partition(self):
self._test_gpu(num_partitions=1)

def _test_gpu(self, do_half_precision=False, num_partitions=2):
entity_name = "e"
relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
base_config = ConfigSchema(
@@ -507,7 +511,7 @@ def _test_gpu(self, do_half_precision=False):
num_batch_negs=64,
num_uniform_negs=64,
relations=[relation_config],
entities={entity_name: EntitySchema(num_partitions=2)},
entities={entity_name: EntitySchema(num_partitions=num_partitions)},
entity_path=None, # filled in later
edge_paths=[], # filled in later
checkpoint_path=self.checkpoint_path.name,
@@ -536,17 +540,21 @@ def _test_gpu(self, do_half_precision=False):
def _test_distributed(self, num_partitions):
sync_path = TemporaryDirectory()
self.addCleanup(sync_path.cleanup)
entity_name = "e"
e1 = "e1"
e2 = "e2"
relation_config = RelationSchema(
name="r",
lhs=entity_name,
rhs=entity_name,
lhs=e1,
rhs=e2,
operator="linear", # To exercise the parameter server.
)
base_config = ConfigSchema(
dimension=10,
relations=[relation_config],
entities={entity_name: EntitySchema(num_partitions=num_partitions)},
entities={
e1: EntitySchema(num_partitions=num_partitions),
e2: EntitySchema(num_partitions=4),
},
entity_path=None, # filled in later
edge_paths=[], # filled in later
checkpoint_path=self.checkpoint_path.name,
22 changes: 10 additions & 12 deletions torchbiggraph/train.py
@@ -919,18 +919,16 @@ def _maybe_write_checkpoint(
f"{self.iteration_manager.num_edge_chunks}"
)
if self.rank == 0:
for entity, econfig in config.entities.items():
if econfig.num_partitions == 1:
logger.info(f"Writing {entity} embeddings")
embs = self.holder.unpartitioned_embeddings[entity]
optimizer = self.trainer.unpartitioned_optimizers[entity]
self.checkpoint_manager.write(
entity,
UNPARTITIONED,
embs.detach(),
optimizer.state_dict(),
unpartitioned=True,
)
for entity, embs in self.holder.unpartitioned_embeddings.items():
logger.info(f"Writing {entity} embeddings")
optimizer = self.trainer.unpartitioned_optimizers[entity]
self.checkpoint_manager.write(
entity,
UNPARTITIONED,
embs.detach(),
optimizer.state_dict(),
unpartitioned=True,
)

logger.info("Writing the metadata")
state_dict: ModuleStateDict = self.model.state_dict()
61 changes: 14 additions & 47 deletions torchbiggraph/train_gpu.py
@@ -9,51 +9,18 @@
import argparse
import ctypes
import logging
import math
import os
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import partial
from multiprocessing.connection import wait as mp_wait
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
NamedTuple,
Optional,
Set,
Tuple,
)
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple

import torch
import torch.distributed as td
import torch.multiprocessing as mp
from torch.optim import Adagrad, Optimizer
from torchbiggraph import _C
from torchbiggraph.batching import AbstractBatchProcessor, process_in_batches
from torchbiggraph.bucket_scheduling import (
AbstractBucketScheduler,
BucketStats,
DistributedBucketScheduler,
LockServer,
SingleMachineBucketScheduler,
)
from torchbiggraph.checkpoint_manager import (
CheckpointManager,
ConfigMetadataProvider,
MetadataProvider,
PartitionClient,
)
from torchbiggraph.config import (
ConfigFileLoader,
ConfigSchema,
RelationSchema,
add_to_sys_path,
)
from torchbiggraph.config import ConfigFileLoader, ConfigSchema, add_to_sys_path
from torchbiggraph.edgelist import EdgeList
from torchbiggraph.entitylist import EntityList
from torchbiggraph.graph_storages import EDGE_STORAGES, ENTITY_STORAGES
@@ -124,8 +91,8 @@ def unaccounted(self) -> float:


class SubprocessArgs(NamedTuple):
lhs_partitioned_types: Set[str]
rhs_partitioned_types: Set[str]
lhs_types: Set[str]
rhs_types: Set[str]
lhs_part: Partition
rhs_part: Partition
lhs_subpart: SubPartition
@@ -194,8 +161,8 @@ def run(self) -> None:
break

stats = self.do_one_job(
lhs_partitioned_types=job.lhs_partitioned_types,
rhs_partitioned_types=job.rhs_partitioned_types,
lhs_types=job.lhs_types,
rhs_types=job.rhs_types,
lhs_part=job.lhs_part,
rhs_part=job.rhs_part,
lhs_subpart=job.lhs_subpart,
@@ -217,8 +184,8 @@ def run(self) -> None:

def do_one_job( # noqa
self,
lhs_partitioned_types: Set[str],
rhs_partitioned_types: Set[str],
lhs_types: Set[str],
rhs_types: Set[str],
lhs_part: Partition,
rhs_part: Partition,
lhs_subpart: SubPartition,
@@ -243,9 +210,9 @@ def do_one_job( # noqa
occurrences: Dict[
Tuple[EntityName, Partition, SubPartition], Set[Side]
] = defaultdict(set)
for entity_name in lhs_partitioned_types:
for entity_name in lhs_types:
occurrences[entity_name, lhs_part, lhs_subpart].add(Side.LHS)
for entity_name in rhs_partitioned_types:
for entity_name in rhs_types:
occurrences[entity_name, rhs_part, rhs_subpart].add(Side.RHS)

if lhs_part != rhs_part: # Bipartite
@@ -323,10 +290,10 @@ def do_one_job( # noqa
Tuple[EntityName, Partition, SubPartition], Set[Side]
] = defaultdict(set)
if next_lhs_subpart is not None:
for entity_name in lhs_partitioned_types:
for entity_name in lhs_types:
next_occurrences[entity_name, lhs_part, next_lhs_subpart].add(Side.LHS)
if next_rhs_subpart is not None:
for entity_name in rhs_partitioned_types:
for entity_name in rhs_types:
next_occurrences[entity_name, rhs_part, next_rhs_subpart].add(Side.RHS)

tk.start("copy_from_device")
@@ -606,8 +573,8 @@ def schedule(gpu_idx: GPURank) -> None:
self.gpu_pool.schedule(
gpu_idx,
SubprocessArgs(
lhs_partitioned_types=holder.lhs_partitioned_types,
rhs_partitioned_types=holder.rhs_partitioned_types,
lhs_types=holder.lhs_partitioned_types,
rhs_types=holder.rhs_partitioned_types,
lhs_part=cur_b.lhs,
rhs_part=cur_b.rhs,
lhs_subpart=this_bucket[0],
9 changes: 9 additions & 0 deletions torchbiggraph/util.py
@@ -306,6 +306,15 @@ def __init__(self, config: ConfigSchema) -> None:
self.nparts_rhs, self.rhs_unpartitioned_types, self.rhs_partitioned_types = get_partitioned_types( # noqa
config, Side.RHS
)
if self.nparts_lhs == 1 and self.nparts_rhs == 1:
assert (
config.num_machines == 1
), "Cannot run distributed training with a single partition."
self.lhs_partitioned_types = self.lhs_unpartitioned_types
self.rhs_partitioned_types = self.rhs_unpartitioned_types
self.lhs_unpartitioned_types = set()
self.rhs_unpartitioned_types = set()

self.unpartitioned_embeddings: Dict[EntityName, torch.nn.Parameter] = {}
self.partitioned_embeddings: Dict[
Tuple[EntityName, Partition], torch.nn.Parameter
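
A minimal sketch of the effect of this util.py change, assuming the attribute names shown in the diff, default values for the remaining ConfigSchema fields, and illustrative entity names and paths (none of this code is part of the commit):

from torchbiggraph.config import ConfigSchema, EntitySchema, RelationSchema
from torchbiggraph.util import EmbeddingHolder

# Single partition on both sides of the only relation.
config = ConfigSchema(
    dimension=10,
    relations=[RelationSchema(name="r", lhs="e", rhs="e")],
    entities={"e": EntitySchema(num_partitions=1)},
    entity_path="data/example_graph",         # hypothetical paths
    edge_paths=["data/example_graph/edges"],
    checkpoint_path="model/example_graph",
)

holder = EmbeddingHolder(config)
# With a single partition on both sides, the entity type is folded into the
# partitioned sets, so GPU training handles it through the regular bucket
# path rather than the shared ("unpartitioned") path.
assert holder.lhs_partitioned_types == {"e"}
assert holder.lhs_unpartitioned_types == set()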
