Make LockServer consider graph structure and cardinalities for affinity
Summary:
When assigning a new bucket to a trainer that was already holding one, the LockServer used to pick the bucket with the highest "affinity" to the previous one. This was defined very simply: keep the same left-hand side partition if possible and, failing that, at least the right-hand side one.

This missed out on many subtleties and optimizations that were possible with some graph structures. For example, it never considered swapping the left- and right-hand sides when the entities on both sides were the same. Or, if the right-hand side entities were the ones with the highest cardinality (and thus the most expensive to swap), it would still prefer to preserve the left-hand side ones.
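
To make the failure concrete, here is a minimal sketch of the old heuristic as just described (the bucket tuples and weights below are illustrative, not the exact removed code):

def old_affinity(new_bucket, old_bucket):
    # Count matching partitions only: a matching lhs is worth 2, a matching
    # rhs is worth 1, regardless of which entities live on each side or how
    # large their embedding tables are.
    lhs_match = new_bucket[0] == old_bucket[0]
    rhs_match = new_bucket[1] == old_bucket[1]
    return 2 * lhs_match + rhs_match

print(old_affinity((1, 0), (0, 1)))  # 0: a free swap of identical sides scores worst
print(old_affinity((0, 2), (0, 1)))  # 2: keeps the lhs even when the rhs entities are costlier to swap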

Here I am fixing the above by making the affinity calculation aware of which entities are on which side and how many there are, so that the affinity more closely reflects the I/O cost incurred by swapping buckets.
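
As a sketch of the new cost model (a simplified rendition of the _pick_bucket logic added below, with a made-up entity type and count):

from typing import Set, Tuple

entities_lhs: Set[str] = {"user"}
entities_rhs: Set[str] = {"user"}  # same entity on both sides, so a swap is free
average_entity_counts = {"user": 1000}

def bucket_parts(bucket: Tuple[int, int]) -> Set[Tuple[str, int]]:
    lhs, rhs = bucket
    return {(e, lhs) for e in entities_lhs} | {(e, rhs) for e in entities_rhs}

def switch_cost(old_bucket: Tuple[int, int], new_bucket: Tuple[int, int]) -> int:
    # Embeddings needed both before and after the switch come for free; we pay
    # (proportionally to entity count) for every (entity, partition) pair that
    # must be saved or loaded, i.e. the symmetric difference of the two sets.
    return sum(
        average_entity_counts[entity]
        for entity, _ in bucket_parts(new_bucket) ^ bucket_parts(old_bucket)
    )

print(switch_cost((0, 1), (1, 0)))  # 0: swapping sides moves no embeddings
print(switch_cost((0, 1), (0, 2)))  # 2000: save partition 1, load partition 2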

Reviewed By: adamlerer

Differential Revision: D17830062

fbshipit-source-id: f5992d242e8189193137e8d869c6b655d696c409
lw authored and facebook-github-bot committed Oct 15, 2019
1 parent 6b41eca commit c07ea71
Showing 2 changed files with 87 additions and 17 deletions.
103 changes: 86 additions & 17 deletions torchbiggraph/bucket_scheduling.py
@@ -9,6 +9,8 @@
import logging
import random
from abc import ABC, abstractmethod
from collections import defaultdict
from statistics import mean
from typing import Dict, List, NamedTuple, Optional, Set, Tuple

from torchbiggraph.config import BucketOrder
@@ -133,6 +135,22 @@ def create_buckets_ordered_by_affinity(
if nparts_lhs <= 0 or nparts_rhs <= 0:
return []

# TODO Change this function to use the same cost model as the LockServer
# when computing affinity (based on the number of entities to save and load)
# rather than just the number of partitions in common. Pay attention to keep
# the complexity of this algorithm linear in the number of buckets. This
# comment is too short to give a full description, but the idea is that only
# a few transitions are possible between a bucket and the next: the one that
# preserves all (ent, part) pairs, the one that preserves only the lhs ones,
# only the rhs ones, only the intersection of the two, or none at all. So we
# can keep a dict from sets of (ent, part) to lists of buckets, and insert
# each bucket into four of those lists, namely the ones for all its (ent,
# part), its lhs ones, its rhs ones and the intersection of its lhs and rhs
# ones. Then, when looking for the next bucket, we figure out the transition
# that is cheapest (among the options defined above), determine the set of
# (ent, part) we need to move to in order to achieve that transition type
# and we look up in the dict to find a bucket containing those (ent, part).

# This is our "source of truth" on what buckets we haven't outputted yet. It
# can be queried in constant time.
remaining: Set[Bucket] = set()
@@ -329,13 +347,24 @@ def __init__(
nparts_rhs: int,
entities_lhs: Set[EntityName],
entities_rhs: Set[EntityName],
entity_counts: Dict[str, List[int]],
init_tree: bool,
) -> None:
super().__init__(num_clients)
self.nparts_lhs: int = nparts_lhs
self.nparts_rhs: int = nparts_rhs
self.entities_lhs: Set[EntityName] = entities_lhs
self.entities_rhs: Set[EntityName] = entities_rhs
# We need the entity counts to estimate the I/O cost of switching from
# one bucket to another (due to saving and loading the checkpoints of
# the embeddings that are not in common). We don't want small variations
# in the embedding table sizes to always force a certain bucket order;
# instead we'd like to keep some randomness among the buckets that are
# effectively equivalent to ensure good mixing. So we replace the counts
# by the average of all the counts of a certain entity, as they all
# follow the same distribution.
self.average_entity_counts: Dict[str, int] = \
{entity: round(mean(counts)) for entity, counts in entity_counts.items()}
self.init_tree = init_tree

self.active: Dict[Bucket, Rank] = {}
@@ -390,6 +419,62 @@ def _is_initialized(self, bucket: Bucket) -> bool:
for entity in self.entities_rhs)
)

def _pick_bucket(
self,
buckets: List[Bucket],
maybe_old_bucket: Optional[Bucket],
) -> Bucket:
# We return a bucket (the lexicographically smallest one) among those
# that minimize the I/O cost of loading them (from scratch) or switching
# to them (if another bucket was already loaded). The cost of loading a
bucket is the cost of loading all the embeddings it needs, and is
# proportional to the number of entities it needs. The cost of switching
# buckets is the cost of storing the embeddings that were needed before
# but are not needed anymore and, conversely, loading the ones that
# weren't needed but now are: embeddings that were needed and still are
# come for free! Thus this function will tend to keep at least one side
# of the bucket in place and, depending on the "bipartitedness" of the
# graph (ranging from each entity appearing on only one side to all
# entities appearing on both sides), it may optimize further, for
# example by swapping the two sides. When loading from scratch, it will
prefer diagonal buckets (i.e., those of the form (i, i)), unless
# the graph is fully bipartite.
# TODO If init_tree is enabled, during the first pass we might want to
# _not_ choose diagonal buckets as they tend to keep the number of
# initialized embeddings small and thus delay the time at which other
# trainers can start working.
old_entities_parts: Set[Tuple[EntityName, Partition]] = set()
if maybe_old_bucket is not None:
old_entities_parts.update(
(entity, maybe_old_bucket.lhs) for entity in self.entities_lhs)
old_entities_parts.update(
(entity, maybe_old_bucket.rhs) for entity in self.entities_rhs)

buckets_by_cost: Dict[int, List[Bucket]] = defaultdict(list)
for bucket in buckets:
new_entities_parts: Set[Tuple[EntityName, Partition]] = set()
new_entities_parts.update(
(entity, bucket.lhs) for entity in self.entities_lhs)
new_entities_parts.update(
(entity, bucket.rhs) for entity in self.entities_rhs)

cost = sum(
self.average_entity_counts[entity]
for entity, _ in new_entities_parts.symmetric_difference(old_entities_parts))

buckets_by_cost[cost].append(bucket)

min_cost = min(buckets_by_cost.keys())
cheapest_buckets = buckets_by_cost[min_cost]

# TODO It may be interesting to get a random bucket among the acquirable
# ones (after filtering by highest affinity, if possible), rather than
# the lexicographically smallest one, to better mix the edges, but we
# should first empirically verify that this doesn't degrade accuracy.
# return random.choice(cheapest_buckets)

return cheapest_buckets[0]

def acquire_bucket(
self,
rank: Rank,
@@ -439,23 +524,7 @@ def acquire_bucket(
if len(acquirable_buckets) == 0:
return None, remaining

# TODO It may be interesting to get a random bucket among the acquirable
# ones (after filtering by highest affinity, if possible), rather than
# the lexicographically smallest one, to better mix the edges, but we
# should first empirically verify that this doesn't degrade accuracy.
# random.shuffle(acquirable_buckets)

if maybe_old_bucket is not None:
# The linter isn't smart enough to figure out that the closure is
# capturing a non-None value, thus alias it to a new variable, which
# will get a non-Optional type.
old_bucket = maybe_old_bucket
new_bucket = max(
acquirable_buckets,
key=lambda b: - (2 * (b.lhs == old_bucket.lhs)
+ (b.rhs == old_bucket.rhs)))
else:
new_bucket = acquirable_buckets[0]
new_bucket = self._pick_bucket(acquirable_buckets, maybe_old_bucket)

self.active[new_bucket] = rank
self.done.add(new_bucket)
1 change: 1 addition & 0 deletions torchbiggraph/train.py
@@ -364,6 +364,7 @@ def train_and_report_stats(
nparts_rhs=nparts_rhs,
entities_lhs=lhs_partitioned_types,
entities_rhs=rhs_partitioned_types,
entity_counts=entity_counts,
init_tree=config.distributed_tree_init_order,
),
process_name="LockServer",
