Simpler partition swapping
Summary:
I always had the feeling that `swap_partitioned_embeddings` was more complicated than it had to be, and I realized this is because, in addition to the saving+loading logic, it also has to convert between the (entity_type, part) identifiers used by the checkpoints and the (entity_type, side) identifiers used by the model.

In this diff I'm adding an extra level of indirection (the solution to all problems!) which fixes this by storing the partitions in a so-called holder rather than in the model itself, and doing the conversion outside of `swap_partitioned_embeddings`.

Separating the model from the holder also means that we can load and drop embeddings on the model as we want, without requiring checkpointing or loading from disk. Thus we reset the model (remove all embeddings) at the end of each pass, so that all of its parameters are the global parameters. This simplifies some code that previously had to filter out the unpartitioned embeddings when checkpointing the global params.
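
As a point of reference, here is a minimal sketch of what such a holder could look like. The field names (`lhs_partitioned_types`, `nparts_lhs`, `partitioned_embeddings`, etc.) are taken from the diff below; the constructor logic is only a rough approximation, not the actual implementation.

```python
from typing import Dict, Set, Tuple

import torch


class EmbeddingHolder:
    """Owns the embeddings currently in memory, keyed the way checkpoints key them."""

    def __init__(self, config) -> None:
        self.lhs_partitioned_types: Set[str] = set()
        self.rhs_partitioned_types: Set[str] = set()
        self.lhs_unpartitioned_types: Set[str] = set()
        self.rhs_unpartitioned_types: Set[str] = set()
        # Classify each entity type appearing on each side of a relation.
        for relation in config.relations:
            for entity_name, part_set, unpart_set in [
                (relation.lhs, self.lhs_partitioned_types, self.lhs_unpartitioned_types),
                (relation.rhs, self.rhs_partitioned_types, self.rhs_unpartitioned_types),
            ]:
                if config.entities[entity_name].num_partitions == 1:
                    unpart_set.add(entity_name)
                else:
                    part_set.add(entity_name)
        self.nparts_lhs: int = max(
            (config.entities[e].num_partitions for e in self.lhs_partitioned_types),
            default=1,
        )
        self.nparts_rhs: int = max(
            (config.entities[e].num_partitions for e in self.rhs_partitioned_types),
            default=1,
        )
        # Keyed by the identifiers the checkpoints use: (entity_type, partition)
        # for partitioned types, the entity_type alone for unpartitioned ones.
        self.partitioned_embeddings: Dict[Tuple[str, int], torch.nn.Parameter] = {}
        self.unpartitioned_embeddings: Dict[str, torch.nn.Parameter] = {}
```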

Reviewed By: adamlerer

Differential Revision: D18373175

fbshipit-source-id: dd32b6ca009066ce8a565ddc471618d1941860df
lw authored and facebook-github-bot committed Dec 5, 2019
1 parent fe6c371 commit 3997200
Showing 5 changed files with 142 additions and 173 deletions.
43 changes: 21 additions & 22 deletions torchbiggraph/eval.py
@@ -11,7 +11,7 @@
 import time
 from functools import partial
 from itertools import chain
-from typing import Callable, Generator, List, Optional, Tuple
+from typing import Callable, Dict, Generator, List, Optional, Tuple
 
 import torch
 
@@ -31,11 +31,11 @@
 from torchbiggraph.stats import Stats, average_of_sums
 from torchbiggraph.types import Bucket, EntityName, Partition, Side
 from torchbiggraph.util import (
+    EmbeddingHolder,
     compute_randomized_auc,
     create_pool,
     get_async_result,
     get_num_workers,
-    get_partitioned_types,
     set_logging_verbosity,
     setup_logging,
     split_almost_equally,
@@ -115,8 +115,7 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
         assert embs.is_shared()
         return torch.nn.Parameter(embs)
 
-    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
-    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
+    holder = EmbeddingHolder(config)
 
     num_workers = get_num_workers(config.workers)
     pool = create_pool(
@@ -135,11 +134,9 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
 
     model.eval()
 
-    for entity, econfig in config.entities.items():
-        if econfig.num_partitions == 1:
-            embs = load_embeddings(entity, Partition(0))
-            model.set_embeddings(entity, embs, Side.LHS)
-            model.set_embeddings(entity, embs, Side.RHS)
+    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
+        embs = load_embeddings(entity, Partition(0))
+        holder.unpartitioned_embeddings[entity] = embs
 
     all_stats: List[Stats] = []
     for edge_path_idx, edge_path in enumerate(config.edge_paths):
@@ -149,22 +146,22 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
         edge_storage = EDGE_STORAGES.make_instance(edge_path)
 
         all_edge_path_stats = []
-        last_lhs, last_rhs = None, None
-        for bucket in create_buckets_ordered_lexicographically(nparts_lhs, nparts_rhs):
+        # FIXME This order assumes higher affinity on the left-hand side, as it's
+        # the one changing more slowly. Make this adaptive to the actual affinity.
+        for bucket in create_buckets_ordered_lexicographically(holder.nparts_lhs, holder.nparts_rhs):
             tic = time.time()
             # logger.info(f"{bucket}: Loading entities")
 
-            if last_lhs != bucket.lhs:
-                for e in lhs_partitioned_types:
-                    model.clear_embeddings(e, Side.LHS)
-                    embs = load_embeddings(e, bucket.lhs)
-                    model.set_embeddings(e, embs, Side.LHS)
-            if last_rhs != bucket.rhs:
-                for e in rhs_partitioned_types:
-                    model.clear_embeddings(e, Side.RHS)
-                    embs = load_embeddings(e, bucket.rhs)
-                    model.set_embeddings(e, embs, Side.RHS)
-            last_lhs, last_rhs = bucket.lhs, bucket.rhs
+            old_parts = set(holder.partitioned_embeddings.keys())
+            new_parts = {(e, bucket.lhs) for e in holder.lhs_partitioned_types} \
+                | {(e, bucket.rhs) for e in holder.rhs_partitioned_types}
+            for entity, part in old_parts - new_parts:
+                del holder.partitioned_embeddings[entity, part]
+            for entity, part in new_parts - old_parts:
+                embs = load_embeddings(entity, part)
+                holder.partitioned_embeddings[entity, part] = embs
 
+            model.set_all_embeddings(holder, bucket)
+
             # logger.info(f"{bucket}: Loading edges")
             edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
@@ -199,6 +196,8 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
                 f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                 f"bucket {bucket}: {mean_bucket_stats}")
 
+            model.clear_all_embeddings()
+
             yield edge_path_idx, bucket, mean_bucket_stats
 
         total_edge_path_stats = Stats.sum(all_edge_path_stats)
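To make the swap logic above concrete, here is a toy, self-contained walk-through of the same old_parts/new_parts bookkeeping. The 2x2 bucket grid and the entity type names are invented for the example. Because buckets are visited in lexicographic order, the left-hand partition changes more slowly, so between consecutive buckets usually only the right-hand side is dropped and reloaded:

```python
from itertools import product

lhs_partitioned_types = {"user"}   # hypothetical entity types
rhs_partitioned_types = {"item"}
loaded = {}  # (entity_type, partition) -> placeholder for the loaded embeddings

for lhs_part, rhs_part in product(range(2), range(2)):  # (0,0), (0,1), (1,0), (1,1)
    old_parts = set(loaded.keys())
    new_parts = ({(e, lhs_part) for e in lhs_partitioned_types}
                 | {(e, rhs_part) for e in rhs_partitioned_types})
    for key in old_parts - new_parts:
        del loaded[key]                # dropped, no checkpointing needed here
    for key in new_parts - old_parts:
        loaded[key] = f"embs{key}"     # stands in for load_embeddings(*key)
    print((lhs_part, rhs_part), "loaded:", sorted(new_parts - old_parts))

# ('user', 0) is read once and reused for buckets (0, 0) and (0, 1);
# only the 'item' partitions are swapped in between.
```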
40 changes: 22 additions & 18 deletions torchbiggraph/model.py
@@ -34,8 +34,8 @@
 from torchbiggraph.graph_storages import RELATION_TYPE_STORAGES
 from torchbiggraph.plugin import PluginRegistry
 from torchbiggraph.tensorlist import TensorList
-from torchbiggraph.types import FloatTensorType, LongTensorType, Side
-from torchbiggraph.util import CouldNotLoadData
+from torchbiggraph.types import Bucket, FloatTensorType, LongTensorType, Side
+from torchbiggraph.util import CouldNotLoadData, EmbeddingHolder
 
 
 logger = logging.getLogger("torchbiggraph")
@@ -815,28 +815,32 @@ def __init__(
 
         self.max_norm: Optional[float] = max_norm
 
-    def set_embeddings(self, entity: str, weights: nn.Parameter, side: Side):
+    def set_embeddings(self, entity: str, side: Side, weights: nn.Parameter) -> None:
         if self.entities[entity].featurized:
             emb = FeaturizedEmbedding(weights, max_norm=self.max_norm)
         else:
             emb = SimpleEmbedding(weights, max_norm=self.max_norm)
         side.pick(self.lhs_embs, self.rhs_embs)[self.EMB_PREFIX + entity] = emb
 
-    def clear_embeddings(self, entity: str, side: Side) -> None:
-        embs = side.pick(self.lhs_embs, self.rhs_embs)
-        try:
-            del embs[self.EMB_PREFIX + entity]
-        except KeyError:
-            pass
-
-    def get_embeddings(self, entity: str, side: Side) -> nn.Parameter:
-        embs = side.pick(self.lhs_embs, self.rhs_embs)
-        try:
-            emb = embs[self.EMB_PREFIX + entity]
-        except KeyError:
-            return None
-        else:
-            return emb.weight
+    def set_all_embeddings(self, holder: EmbeddingHolder, bucket: Bucket) -> None:
+        # This could be a method of the EmbeddingHolder, but it's here as
+        # utils.py cannot depend on model.py.
+        for entity in holder.lhs_unpartitioned_types:
+            self.set_embeddings(
+                entity, Side.LHS, holder.unpartitioned_embeddings[entity])
+        for entity in holder.rhs_unpartitioned_types:
+            self.set_embeddings(
+                entity, Side.RHS, holder.unpartitioned_embeddings[entity])
+        for entity in holder.lhs_partitioned_types:
+            self.set_embeddings(
+                entity, Side.LHS, holder.partitioned_embeddings[entity, bucket.lhs])
+        for entity in holder.rhs_partitioned_types:
+            self.set_embeddings(
+                entity, Side.RHS, holder.partitioned_embeddings[entity, bucket.rhs])
+
+    def clear_all_embeddings(self) -> None:
+        self.lhs_embs.clear()
+        self.rhs_embs.clear()
 
     def adjust_embs(
         self,
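The intended life cycle of these two methods, as used by the eval loop above, can be summarized in a small hedged sketch; `run_one_bucket` and `do_work` are made-up names, not part of the codebase:

```python
# Hedged sketch: the holder keeps checkpoint-style (entity_type, partition) keys,
# set_all_embeddings() translates them into the model's (entity_type, side) slots
# for the duration of one bucket, and clear_all_embeddings() resets the model so
# that its state_dict() contains only the global (non-embedding) parameters.
def run_one_bucket(model, holder, bucket, do_work):
    model.set_all_embeddings(holder, bucket)
    try:
        do_work()  # train or evaluate on this bucket's edges
    finally:
        model.clear_all_embeddings()
```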
8 changes: 0 additions & 8 deletions torchbiggraph/parameter_sharing.py
@@ -499,11 +499,3 @@ def join(self) -> None:
         self.q.put(('join', None))
         self.check()
         self.p.join()
-
-    def share_model_params(self, model: nn.Module) -> None:
-        shared_parameters: Set[int] = set()
-        for k, v in ModuleStateDict(model.state_dict()).items():
-            if v._cdata not in shared_parameters:
-                shared_parameters.add(v._cdata)
-                logger.info(f"Adding {k} ({v.numel()} params) to parameter server")
-                self.set_param(k, v.data)