From 399720009515b957e2304be4385bea3886720435 Mon Sep 17 00:00:00 2001
From: Luca Wehrstedt
Date: Thu, 5 Dec 2019 01:56:14 -0800
Subject: [PATCH] Simpler partition swapping

Summary:
I always had the feeling that `swap_partitioned_embeddings` was more complicated than it had to be, and I realized why: in addition to the saving and loading logic, it also has to convert between the (entity_type, part) identifiers used by the checkpoints and the (entity_type, side) identifiers used by the model.

In this diff I'm adding an extra level of indirection (the solution to all problems!) that fixes this: the partitions are stored in a so-called holder rather than in the model itself, and the conversion happens outside of `swap_partitioned_embeddings`.

Separating the model from the holder also means that we can attach and drop embeddings on the model at will, without requiring a checkpoint write or a load from disk. We therefore reset the model (remove all embeddings) at the end of each pass, so that its parameters are exactly the global parameters. This simplifies the code that previously had to filter out the unpartitioned embeddings when checkpointing the global params.

Reviewed By: adamlerer

Differential Revision: D18373175

fbshipit-source-id: dd32b6ca009066ce8a565ddc471618d1941860df
---
 torchbiggraph/eval.py              |  43 +++----
 torchbiggraph/model.py             |  40 +++---
 torchbiggraph/parameter_sharing.py |   8 --
 torchbiggraph/train.py             | 197 ++++++++++++-----------
 torchbiggraph/util.py              |  27 ++--
 5 files changed, 142 insertions(+), 173 deletions(-)

diff --git a/torchbiggraph/eval.py b/torchbiggraph/eval.py
index 36e37a33..05bc4fa3 100644
--- a/torchbiggraph/eval.py
+++ b/torchbiggraph/eval.py
@@ -11,7 +11,7 @@
 import time
 from functools import partial
 from itertools import chain
-from typing import Callable, Generator, List, Optional, Tuple
+from typing import Callable, Dict, Generator, List, Optional, Tuple
 
 import torch
 
@@ -31,11 +31,11 @@
 from torchbiggraph.stats import Stats, average_of_sums
 from torchbiggraph.types import Bucket, EntityName, Partition, Side
 from torchbiggraph.util import (
+    EmbeddingHolder,
     compute_randomized_auc,
     create_pool,
     get_async_result,
     get_num_workers,
-    get_partitioned_types,
     set_logging_verbosity,
     setup_logging,
     split_almost_equally,
@@ -115,8 +115,7 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
         assert embs.is_shared()
         return torch.nn.Parameter(embs)
 
-    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
-    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
+    holder = EmbeddingHolder(config)
 
     num_workers = get_num_workers(config.workers)
     pool = create_pool(
@@ -135,11 +134,9 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
 
     model.eval()
 
-    for entity, econfig in config.entities.items():
-        if econfig.num_partitions == 1:
-            embs = load_embeddings(entity, Partition(0))
-            model.set_embeddings(entity, embs, Side.LHS)
-            model.set_embeddings(entity, embs, Side.RHS)
+    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
+        embs = load_embeddings(entity, Partition(0))
+        holder.unpartitioned_embeddings[entity] = embs
 
     all_stats: List[Stats] = []
     for edge_path_idx, edge_path in enumerate(config.edge_paths):
@@ -149,22 +146,22 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
         edge_storage = EDGE_STORAGES.make_instance(edge_path)
         all_edge_path_stats = []
-        last_lhs, last_rhs = None, None
-        for 
bucket in create_buckets_ordered_lexicographically(nparts_lhs, nparts_rhs): + # FIXME This order assumes higher affinity on the left-hand side, as it's + # the one changing more slowly. Make this adaptive to the actual affinity. + for bucket in create_buckets_ordered_lexicographically(holder.nparts_lhs, holder.nparts_rhs): tic = time.time() # logger.info(f"{bucket}: Loading entities") - if last_lhs != bucket.lhs: - for e in lhs_partitioned_types: - model.clear_embeddings(e, Side.LHS) - embs = load_embeddings(e, bucket.lhs) - model.set_embeddings(e, embs, Side.LHS) - if last_rhs != bucket.rhs: - for e in rhs_partitioned_types: - model.clear_embeddings(e, Side.RHS) - embs = load_embeddings(e, bucket.rhs) - model.set_embeddings(e, embs, Side.RHS) - last_lhs, last_rhs = bucket.lhs, bucket.rhs + old_parts = set(holder.partitioned_embeddings.keys()) + new_parts = {(e, bucket.lhs) for e in holder.lhs_partitioned_types} \ + | {(e, bucket.rhs) for e in holder.rhs_partitioned_types} + for entity, part in old_parts - new_parts: + del holder.partitioned_embeddings[entity, part] + for entity, part in new_parts - old_parts: + embs = load_embeddings(entity, part) + holder.partitioned_embeddings[entity, part] = embs + + model.set_all_embeddings(holder, bucket) # logger.info(f"{bucket}: Loading edges") edges = edge_storage.load_edges(bucket.lhs, bucket.rhs) @@ -199,6 +196,8 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter: f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, " f"bucket {bucket}: {mean_bucket_stats}") + model.clear_all_embeddings() + yield edge_path_idx, bucket, mean_bucket_stats total_edge_path_stats = Stats.sum(all_edge_path_stats) diff --git a/torchbiggraph/model.py b/torchbiggraph/model.py index ec4e6396..b5403cbb 100644 --- a/torchbiggraph/model.py +++ b/torchbiggraph/model.py @@ -34,8 +34,8 @@ from torchbiggraph.graph_storages import RELATION_TYPE_STORAGES from torchbiggraph.plugin import PluginRegistry from torchbiggraph.tensorlist import TensorList -from torchbiggraph.types import FloatTensorType, LongTensorType, Side -from torchbiggraph.util import CouldNotLoadData +from torchbiggraph.types import Bucket, FloatTensorType, LongTensorType, Side +from torchbiggraph.util import CouldNotLoadData, EmbeddingHolder logger = logging.getLogger("torchbiggraph") @@ -815,28 +815,32 @@ def __init__( self.max_norm: Optional[float] = max_norm - def set_embeddings(self, entity: str, weights: nn.Parameter, side: Side): + def set_embeddings(self, entity: str, side: Side, weights: nn.Parameter) -> None: if self.entities[entity].featurized: emb = FeaturizedEmbedding(weights, max_norm=self.max_norm) else: emb = SimpleEmbedding(weights, max_norm=self.max_norm) side.pick(self.lhs_embs, self.rhs_embs)[self.EMB_PREFIX + entity] = emb - def clear_embeddings(self, entity: str, side: Side) -> None: - embs = side.pick(self.lhs_embs, self.rhs_embs) - try: - del embs[self.EMB_PREFIX + entity] - except KeyError: - pass - - def get_embeddings(self, entity: str, side: Side) -> nn.Parameter: - embs = side.pick(self.lhs_embs, self.rhs_embs) - try: - emb = embs[self.EMB_PREFIX + entity] - except KeyError: - return None - else: - return emb.weight + def set_all_embeddings(self, holder: EmbeddingHolder, bucket: Bucket) -> None: + # This could be a method of the EmbeddingHolder, but it's here as + # utils.py cannot depend on model.py. 
+ for entity in holder.lhs_unpartitioned_types: + self.set_embeddings( + entity, Side.LHS, holder.unpartitioned_embeddings[entity]) + for entity in holder.rhs_unpartitioned_types: + self.set_embeddings( + entity, Side.RHS, holder.unpartitioned_embeddings[entity]) + for entity in holder.lhs_partitioned_types: + self.set_embeddings( + entity, Side.LHS, holder.partitioned_embeddings[entity, bucket.lhs]) + for entity in holder.rhs_partitioned_types: + self.set_embeddings( + entity, Side.RHS, holder.partitioned_embeddings[entity, bucket.rhs]) + + def clear_all_embeddings(self) -> None: + self.lhs_embs.clear() + self.rhs_embs.clear() def adjust_embs( self, diff --git a/torchbiggraph/parameter_sharing.py b/torchbiggraph/parameter_sharing.py index 02918677..e8a7189b 100644 --- a/torchbiggraph/parameter_sharing.py +++ b/torchbiggraph/parameter_sharing.py @@ -499,11 +499,3 @@ def join(self) -> None: self.q.put(('join', None)) self.check() self.p.join() - - def share_model_params(self, model: nn.Module) -> None: - shared_parameters: Set[int] = set() - for k, v in ModuleStateDict(model.state_dict()).items(): - if v._cdata not in shared_parameters: - shared_parameters.add(v._cdata) - logger.info(f"Adding {k} ({v.numel()} params) to parameter server") - self.set_param(k, v.data) diff --git a/torchbiggraph/train.py b/torchbiggraph/train.py index 44de5c09..d1a506d0 100644 --- a/torchbiggraph/train.py +++ b/torchbiggraph/train.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from functools import partial from itertools import chain -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple import torch import torch.distributed as td @@ -73,11 +73,11 @@ from torchbiggraph.util import ( BucketLogger, DummyOptimizer, + EmbeddingHolder, create_pool, fast_approx_rand, get_async_result, get_num_workers, - get_partitioned_types, hide_distributed_logging, round_up_to_nearest_multiple, set_logging_verbosity, @@ -98,14 +98,15 @@ class Trainer(AbstractBatchProcessor): def __init__( self, - global_optimizer: Optimizer, + model_optimizer: Optimizer, loss_fn: str, margin: float, relations: List[RelationSchema], ) -> None: super().__init__() - self.global_optimizer = global_optimizer - self.entity_optimizers: Dict[Tuple[EntityName, Partition], Optimizer] = {} + self.model_optimizer = model_optimizer + self.unpartitioned_optimizers: Dict[EntityName, Optimizer] = {} + self.partitioned_optimizers: Dict[Tuple[EntityName, Partition], Optimizer] = {} loss_fn_class = LOSS_FUNCTIONS.get_class(loss_fn) # TODO This is awful! Can we do better? 
@@ -139,8 +140,10 @@ def process_one_batch( count=len(batch_edges)) loss.backward() - self.global_optimizer.step(closure=None) - for optimizer in self.entity_optimizers.values(): + self.model_optimizer.step(closure=None) + for optimizer in self.unpartitioned_optimizers.values(): + optimizer.step(closure=None) + for optimizer in self.partitioned_optimizers.values(): optimizer.step(closure=None) return stats @@ -287,8 +290,6 @@ def should_preserve_old_checkpoint( def get_num_edge_chunks( edge_paths: List[str], - nparts_lhs: int, - nparts_rhs: int, max_edges_per_chunk: int, ) -> int: max_edges_per_bucket = 0 @@ -335,12 +336,11 @@ def train_and_report_stats( entity_counts[entity].append(entity_storage.load_count(entity, part)) # Figure out how many lhs and rhs partitions we need - nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS) - nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS) + holder = EmbeddingHolder(config) logger.debug( - f"nparts {nparts_lhs} {nparts_rhs} " - f"types {lhs_partitioned_types} {rhs_partitioned_types}") - total_buckets = nparts_lhs * nparts_rhs + f"nparts {holder.nparts_lhs} {holder.nparts_rhs} " + f"types {holder.lhs_partitioned_types} {holder.rhs_partitioned_types}") + total_buckets = holder.nparts_lhs * holder.nparts_rhs sync: AbstractSynchronizer bucket_scheduler: AbstractBucketScheduler @@ -360,10 +360,10 @@ def train_and_report_stats( start_server( LockServer( num_clients=len(ranks.trainers), - nparts_lhs=nparts_lhs, - nparts_rhs=nparts_rhs, - entities_lhs=lhs_partitioned_types, - entities_rhs=rhs_partitioned_types, + nparts_lhs=holder.nparts_lhs, + nparts_rhs=holder.nparts_rhs, + entities_lhs=holder.lhs_partitioned_types, + entities_rhs=holder.rhs_partitioned_types, entity_counts=entity_counts, init_tree=config.distributed_tree_init_order, ), @@ -429,7 +429,7 @@ def train_and_report_stats( else: sync = DummySynchronizer() bucket_scheduler = SingleMachineBucketScheduler( - nparts_lhs, nparts_rhs, config.bucket_order) + holder.nparts_lhs, holder.nparts_rhs, config.bucket_order) parameter_sharer = None partition_client = None hide_distributed_logging() @@ -478,7 +478,7 @@ def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimi num_edge_chunks = config.num_edge_chunks else: num_edge_chunks = get_num_edge_chunks( - config.edge_paths, nparts_lhs, nparts_rhs, config.max_edges_per_chunk) + config.edge_paths, config.max_edges_per_chunk) iteration_manager = IterationManager( config.num_epochs, config.edge_paths, num_edge_chunks, iteration_idx=checkpoint_manager.checkpoint_version) @@ -494,7 +494,7 @@ def load_embeddings( part: Partition, strict: bool = False, force_dirty: bool = False, - ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]: + ) -> Tuple[torch.nn.Parameter, Adagrad]: if strict: embs, optim_state = checkpoint_manager.read(entity, part, force_dirty=force_dirty) @@ -510,7 +510,10 @@ def load_embeddings( embs, optim_state = init_embs(entity, entity_counts[entity][part], config.dimension, config.init_scale) assert embs.is_shared() - return torch.nn.Parameter(embs), optim_state + optimizer = make_optimizer([embs], True) + if optim_state is not None: + optimizer.load_state_dict(optim_state) + return torch.nn.Parameter(embs), optimizer logger.info("Initializing global model...") @@ -519,7 +522,7 @@ def load_embeddings( model.share_memory() if trainer is None: trainer = Trainer( - global_optimizer=make_optimizer(model.parameters(), False), + 
model_optimizer=make_optimizer(model.parameters(), False), loss_fn=config.loss_fn, margin=config.margin, relations=config.relations, @@ -538,22 +541,28 @@ def load_embeddings( if state_dict is not None: model.load_state_dict(state_dict, strict=False) if optim_state is not None: - trainer.global_optimizer.load_state_dict(optim_state) + trainer.model_optimizer.load_state_dict(optim_state) logger.debug("Loading unpartitioned entities...") - for entity, econfig in config.entities.items(): - if econfig.num_partitions == 1: - embs, optim_state = load_embeddings(entity, Partition(0)) - model.set_embeddings(entity, embs, Side.LHS) - model.set_embeddings(entity, embs, Side.RHS) - optimizer = make_optimizer([embs], True) - if optim_state is not None: - optimizer.load_state_dict(optim_state) - trainer.entity_optimizers[(entity, Partition(0))] = optimizer + for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types: + embs, optimizer = load_embeddings(entity, Partition(0)) + holder.unpartitioned_embeddings[entity] = embs + trainer.unpartitioned_optimizers[entity] = optimizer # start communicating shared parameters with the parameter server if parameter_sharer is not None: - parameter_sharer.share_model_params(model) + shared_parameters: Set[int] = set() + for name, param in model.named_parameters(): + if id(param) in shared_parameters: + continue + shared_parameters.add(id(param)) + key = f"model.{name}" + logger.info(f"Adding {key} ({param.numel()} params) to parameter server") + parameter_sharer.set_param(key, param.data) + for entity, embs in holder.unpartitioned_embeddings.items(): + key = f"entity.{entity}" + logger.info(f"Adding {key} ({embs.numel()} params) to parameter server") + parameter_sharer.set_param(key, embs.data) strict = False @@ -561,90 +570,51 @@ def swap_partitioned_embeddings( old_b: Optional[Bucket], new_b: Optional[Bucket], old_stats: Optional[BucketStats], - ): - # 0. given the old and new buckets, construct data structures to keep - # track of old and new embedding (entity, part) tuples - + ) -> int: io_bytes = 0 logger.info(f"Swapping partitioned embeddings {old_b} {new_b}") - types = ([(e, Side.LHS) for e in lhs_partitioned_types] - + [(e, Side.RHS) for e in rhs_partitioned_types]) - old_parts = {(e, old_b.get_partition(side)): side - for e, side in types if old_b is not None} - new_parts = {(e, new_b.get_partition(side)): side - for e, side in types if new_b is not None} + old_parts: Set[Tuple[EntityName, Partition]] = set() + if old_b is not None: + old_parts.update((e, old_b.lhs) for e in holder.lhs_partitioned_types) + old_parts.update((e, old_b.rhs) for e in holder.rhs_partitioned_types) + new_parts: Set[Tuple[EntityName, Partition]] = set() + if new_b is not None: + new_parts.update((e, new_b.lhs) for e in holder.lhs_partitioned_types) + new_parts.update((e, new_b.rhs) for e in holder.rhs_partitioned_types) - to_checkpoint = set(old_parts) - set(new_parts) - preserved = set(old_parts) & set(new_parts) + assert old_parts == holder.partitioned_embeddings.keys() - # 1. 
checkpoint embeddings that will not be used in the next pair - # - if old_b is not None: # there are previous embeddings to checkpoint + if old_b is not None: if old_stats is None: raise TypeError("Got old bucket but not its stats") - logger.info("Writing partitioned embeddings") - for entity, part in to_checkpoint: - side = old_parts[(entity, part)] - side_name = side.pick("lhs", "rhs") - logger.debug(f"Checkpointing ({entity} {part} {side_name})") - embs = model.get_embeddings(entity, side) - optim_key = (entity, part) - optim_state = OptimizerStateDict(trainer.entity_optimizers[optim_key].state_dict()) + logger.info("Saving partitioned embeddings to checkpoint") + for entity, part in old_parts - new_parts: + logger .debug(f"Saving ({entity} {part})") + embs = holder.partitioned_embeddings.pop((entity, part)) + optimizer = trainer.partitioned_optimizers.pop((entity, part)) + checkpoint_manager.write( + entity, part, + embs.detach(), OptimizerStateDict(optimizer.state_dict())) io_bytes += embs.numel() * embs.element_size() # ignore optim state - checkpoint_manager.write(entity, part, embs.detach(), optim_state) - if optim_key in trainer.entity_optimizers: - del trainer.entity_optimizers[optim_key] # these variables are holding large objects; let them be freed del embs - del optim_state + del optimizer bucket_scheduler.release_bucket(old_b, old_stats) - # 2. copy old embeddings that will be used in the next pair - # into a temporary dictionary - # - tmp_emb = {x: model.get_embeddings(x[0], old_parts[x]) for x in preserved} - - for entity, _ in types: - model.clear_embeddings(entity, Side.LHS) - model.clear_embeddings(entity, Side.RHS) - - if new_b is None: # there are no new embeddings to load - return io_bytes - - bucket_logger = BucketLogger(logger, bucket=new_b) - - # 3. 
load new embeddings into the model/optimizer, either from disk - # or the temporary dictionary - # - bucket_logger.info("Loading entities") - for entity, side in types: - part = new_b.get_partition(side) - part_key = (entity, part) - if part_key in tmp_emb: - bucket_logger.debug(f"Loading ({entity}, {part}) from preserved") - embs, optim_state = tmp_emb[part_key], None - else: - bucket_logger.debug(f"Loading ({entity}, {part})") - + if new_b is not None: + logger.info("Loading partitioned embeddings from checkpoint") + for entity, part in new_parts - old_parts: + logger.debug(f"Loading ({entity} {part})") force_dirty = bucket_scheduler.check_and_set_dirty(entity, part) - embs, optim_state = load_embeddings( + embs, optimizer = load_embeddings( entity, part, strict=strict, force_dirty=force_dirty) + holder.partitioned_embeddings[entity, part] = embs + trainer.partitioned_optimizers[entity, part] = optimizer io_bytes += embs.numel() * embs.element_size() # ignore optim state - model.set_embeddings(entity, embs, side) - tmp_emb[part_key] = embs - - optim_key = (entity, part) - if optim_key not in trainer.entity_optimizers: - bucket_logger.debug(f"Resetting optimizer {optim_key}") - optimizer = make_optimizer([embs], True) - if optim_state is not None: - bucket_logger.debug("Setting optim state") - optimizer.load_state_dict(optim_state) - - trainer.entity_optimizers[optim_key] = optimizer + assert new_parts == holder.partitioned_embeddings.keys() return io_bytes @@ -701,6 +671,8 @@ def swap_partitioned_embeddings( io_bytes += swap_partitioned_embeddings(old_b, cur_b, old_stats) + model.set_all_embeddings(holder, cur_b) + current_index = \ (iteration_manager.iteration_idx + 1) * total_buckets - remaining @@ -710,9 +682,9 @@ def swap_partitioned_embeddings( checkpoint_manager.wait_for_marker(current_index - 1) bucket_logger.debug("Prefetching") - for entity in lhs_partitioned_types: + for entity in holder.lhs_partitioned_types: checkpoint_manager.prefetch(entity, next_b.lhs) - for entity in rhs_partitioned_types: + for entity in holder.rhs_partitioned_types: checkpoint_manager.prefetch(entity, next_b.rhs) checkpoint_manager.record_marker(current_index) @@ -812,6 +784,8 @@ def swap_partitioned_embeddings( eval_stats_after = Stats.sum(all_eval_stats_after).average() bucket_logger.info(f"Stats after training: {eval_stats_after}") + model.clear_all_embeddings() + yield current_index, eval_stats_before, stats, eval_stats_after cur_stats = BucketStats( @@ -846,25 +820,16 @@ def swap_partitioned_embeddings( if rank == 0: for entity, econfig in config.entities.items(): if econfig.num_partitions == 1: - embs = model.get_embeddings(entity, Side.LHS) - optimizer = trainer.entity_optimizers[(entity, Partition(0))] - + embs = holder.unpartitioned_embeddings[entity] + optimizer = trainer.unpartitioned_optimizers[entity] checkpoint_manager.write( entity, Partition(0), embs.detach(), OptimizerStateDict(optimizer.state_dict())) - sanitized_state_dict: ModuleStateDict = {} - for k, v in ModuleStateDict(model.state_dict()).items(): - if k.startswith('lhs_embs') or k.startswith('rhs_embs'): - # skipping state that's an entity embedding - continue - sanitized_state_dict[k] = v - logger.info("Writing the metadata") + state_dict: ModuleStateDict = ModuleStateDict(model.state_dict()) checkpoint_manager.write_model( - sanitized_state_dict, - OptimizerStateDict(trainer.global_optimizer.state_dict()), - ) + state_dict, OptimizerStateDict(trainer.model_optimizer.state_dict())) logger.info("Writing the training stats") 
all_stats_dicts: List[Dict[...]] = [] diff --git a/torchbiggraph/util.py b/torchbiggraph/util.py index b252c62f..e95af459 100644 --- a/torchbiggraph/util.py +++ b/torchbiggraph/util.py @@ -21,7 +21,7 @@ from torch.optim import Optimizer from torchbiggraph.config import ConfigSchema -from torchbiggraph.types import Bucket, EntityName, FloatTensorType, Side +from torchbiggraph.types import Bucket, EntityName, FloatTensorType, Partition, Side logger = logging.getLogger("torchbiggraph") @@ -260,16 +260,16 @@ def get_async_result( def get_partitioned_types( config: ConfigSchema, side: Side, -) -> Tuple[int, Set[EntityName]]: - """Return the number of partitions on a given side and the partitioned entity types +) -> Tuple[int, Set[EntityName], Set[EntityName]]: + """Return the number of partitions on a given side and the (un-)partitioned entity types Each of the entity types that appear on the given side (LHS or RHS) of a relation type is split into some number of partitions. The ones that are split into one partition are called "unpartitioned" and behave as if all of their entities belonged to all buckets. The other ones are the "properly" partitioned ones. Currently, they must all be partitioned into the same number of partitions. This - function returns that number and the names of the properly partitioned entity - types. + function returns that number, the names of the unpartitioned entity types and the + names of the properly partitioned entity types. """ entity_names_by_num_parts: Dict[int, Set[EntityName]] = defaultdict(set) @@ -278,17 +278,26 @@ def get_partitioned_types( entity_config = config.entities[entity_name] entity_names_by_num_parts[entity_config.num_partitions].add(entity_name) - if 1 in entity_names_by_num_parts: - del entity_names_by_num_parts[1] + unpartitioned_entity_names = entity_names_by_num_parts.pop(1, set()) if len(entity_names_by_num_parts) == 0: - return 1, set() + return 1, unpartitioned_entity_names, set() if len(entity_names_by_num_parts) > 1: raise RuntimeError("Currently num_partitions must be a single " "value across all partitioned entities.") (num_partitions, partitioned_entity_names), = entity_names_by_num_parts.items() - return num_partitions, partitioned_entity_names + return num_partitions, unpartitioned_entity_names, partitioned_entity_names + + +class EmbeddingHolder: + def __init__(self, config: ConfigSchema) -> None: + self.nparts_lhs, self.lhs_unpartitioned_types, self.lhs_partitioned_types = \ + get_partitioned_types(config, Side.LHS) + self.nparts_rhs, self.rhs_unpartitioned_types, self.rhs_partitioned_types = \ + get_partitioned_types(config, Side.RHS) + self.unpartitioned_embeddings: Dict[EntityName, torch.nn.Parameter] = {} + self.partitioned_embeddings: Dict[Tuple[EntityName, Partition], torch.nn.Parameter] = {} # compute a randomized AUC using a fixed number of sample points
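
For reference, the flow this diff converges on is sketched below: the holder owns all embeddings, the model only borrows them for the duration of a bucket, and the (entity, part) to (entity, side) conversion happens in a single place. This is an illustrative sketch only, not part of the patch: `config`, `model`, `buckets`, `load_embeddings` and `train_on_bucket` are placeholders standing in for the objects and helpers that `train.py` and `eval.py` actually build.

    from torchbiggraph.types import Partition
    from torchbiggraph.util import EmbeddingHolder

    holder = EmbeddingHolder(config)

    # Unpartitioned entities are loaded once and live in the holder, not the model.
    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
        holder.unpartitioned_embeddings[entity] = load_embeddings(entity, Partition(0))

    for bucket in buckets:
        # All (entity, part) bookkeeping touches only holder.partitioned_embeddings:
        # drop the partitions the new bucket doesn't need, load the missing ones.
        old_parts = set(holder.partitioned_embeddings.keys())
        new_parts = {(e, bucket.lhs) for e in holder.lhs_partitioned_types} \
                  | {(e, bucket.rhs) for e in holder.rhs_partitioned_types}
        for entity, part in old_parts - new_parts:
            del holder.partitioned_embeddings[entity, part]
        for entity, part in new_parts - old_parts:
            holder.partitioned_embeddings[entity, part] = load_embeddings(entity, part)

        # The (entity, part) -> (entity, side) conversion happens here, in one call.
        model.set_all_embeddings(holder, bucket)

        train_on_bucket(model, bucket)

        # Reset the model so that its state_dict once again contains only the
        # global (non-embedding) parameters.
        model.clear_all_embeddings()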