From 83ff522cbb86e61d3667f9ed2c441296def247b9 Mon Sep 17 00:00:00 2001
From: Erik Brinkman
Date: Tue, 10 Mar 2020 18:04:15 -0700
Subject: [PATCH] Allow arbitrary dimensions for pbg

Summary: This is a fairly minimal patch to allow arbitrary embedding
dimensions in pbg, configurable per entity type. Note that it does not
translate to filament, i.e. filament still requires everything to have the
same dimension. It's possible that I've missed places where sizes should be
verified, and it's not clear to me how well these changes will play with
GPU filament.

Reviewed By: lw

Differential Revision: D20316816

fbshipit-source-id: 77de40b86a68bb36441ecfea73677d5baf6f8d7a
---
 torchbiggraph/config.py |  5 +++++
 torchbiggraph/model.py  | 34 ++++++++++++++++++++--------------
 torchbiggraph/train.py  |  6 ++++--
 3 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/torchbiggraph/config.py b/torchbiggraph/config.py
index d804d1cb..fe0d337f 100644
--- a/torchbiggraph/config.py
+++ b/torchbiggraph/config.py
@@ -69,6 +69,11 @@ class EntitySchema(Schema):
         metadata={'help': "Whether the entities of this type are represented "
                           "as sets of features."},
     )
+    dimension: Optional[int] = attr.ib(
+        default=None,
+        validator=optional(positive),
+        metadata={'help': "Override the default dimension for this entity."}
+    )
 
 
 @schema
diff --git a/torchbiggraph/model.py b/torchbiggraph/model.py
index b5403cbb..fb1fc354 100644
--- a/torchbiggraph/model.py
+++ b/torchbiggraph/model.py
@@ -767,7 +767,7 @@ class MultiRelationEmbedder(nn.Module):
 
     def __init__(
         self,
-        dim: int,
+        default_dim: int,
         relations: List[RelationSchema],
         entities: Dict[str, EntitySchema],
         num_batch_negs: int,
@@ -783,8 +783,6 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.dim: int = dim
-
         self.relations: List[RelationSchema] = relations
         self.entities: Dict[str, EntitySchema] = entities
         self.num_dynamic_rels: int = num_dynamic_rels
@@ -806,10 +804,11 @@ def __init__(
         self.rhs_embs: nn.ParameterDict = nn.ModuleDict()
 
         if global_emb:
-            self.global_embs: Optional[nn.ParameterDict] = nn.ParameterDict()
-            for entity in entities.keys():
-                self.global_embs[self.EMB_PREFIX + entity] = \
-                    nn.Parameter(torch.zeros((dim,)))
+            global_embs = nn.ParameterDict()
+            for entity, entity_schema in entities.items():
+                global_embs[self.EMB_PREFIX + entity] = \
+                    nn.Parameter(torch.zeros((entity_schema.dimension or default_dim,)))
+            self.global_embs = global_embs
         else:
             self.global_embs: Optional[nn.ParameterDict] = None
 
@@ -1126,12 +1125,17 @@
         dst_pos = self.adjust_embs(dst_pos, rel, dst_entity_type, dst_operator)
 
         num_chunks = ceil_of_ratio(num_pos, chunk_size)
+        src_dim = src_pos.size(-1)
+        dst_dim = dst_pos.size(-1)
         if num_pos < num_chunks * chunk_size:
-            padding = torch.zeros(()).expand((num_chunks * chunk_size - num_pos, self.dim))
-            src_pos = torch.cat((src_pos, padding), dim=0)
-            dst_pos = torch.cat((dst_pos, padding), dim=0)
-        src_pos = src_pos.view((num_chunks, chunk_size, self.dim))
-        dst_pos = dst_pos.view((num_chunks, chunk_size, self.dim))
+            src_padding = torch.zeros(()).expand(
+                (num_chunks * chunk_size - num_pos, src_dim))
+            src_pos = torch.cat((src_pos, src_padding), dim=0)
+            dst_padding = torch.zeros(()).expand(
+                (num_chunks * chunk_size - num_pos, dst_dim))
+            dst_pos = torch.cat((dst_pos, dst_padding), dim=0)
+        src_pos = src_pos.view((num_chunks, chunk_size, src_dim))
+        dst_pos = dst_pos.view((num_chunks, chunk_size, dst_dim))
 
         src_neg, src_ignore_mask = self.prepare_negatives(
             src, src_pos, src_module, src_negative_sampling_method,
@@ -1189,9 +1193,11 @@ def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
     rhs_operators: List[Optional[Union[AbstractOperator, AbstractDynamicOperator]]] = []
     for r in config.relations:
         lhs_operators.append(
-            instantiate_operator(r.operator, Side.LHS, num_dynamic_rels, config.dimension))
+            instantiate_operator(r.operator, Side.LHS, num_dynamic_rels,
+                                 config.entities[r.lhs].dimension or config.dimension))
         rhs_operators.append(
-            instantiate_operator(r.operator, Side.RHS, num_dynamic_rels, config.dimension))
+            instantiate_operator(r.operator, Side.RHS, num_dynamic_rels,
+                                 config.entities[r.rhs].dimension or config.dimension))
 
     comparator_class = COMPARATORS.get_class(config.comparator)
     comparator = comparator_class()
diff --git a/torchbiggraph/train.py b/torchbiggraph/train.py
index c4d448b3..7e7c3b7c 100644
--- a/torchbiggraph/train.py
+++ b/torchbiggraph/train.py
@@ -506,8 +506,10 @@ def load_embeddings(
         if embs is None and loadpath_manager is not None:
             embs, optim_state = loadpath_manager.maybe_read(entity, part)
         if embs is None:
-            embs, optim_state = init_embs(entity, entity_counts[entity][part],
-                                          config.dimension, config.init_scale)
+            embs, optim_state = init_embs(
+                entity, entity_counts[entity][part],
+                config.entities[entity].dimension or config.dimension,
+                config.init_scale)
         assert embs.is_shared()
         embs = torch.nn.Parameter(embs)
         optimizer = make_optimizer([embs], True)
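
Usage example (not part of the patch above): a minimal sketch of how the new
per-entity "dimension" override might look in a config file, assuming the
usual get_torchbiggraph_config() dict format; the entity and relation names,
paths, and hyperparameters below are placeholders, not anything defined by
this patch.

    def get_torchbiggraph_config():
        # Hypothetical example config; all names, paths, and values are placeholders.
        return dict(
            entity_path="data/example",
            edge_paths=["data/example/edges"],
            checkpoint_path="model/example",
            # Default embedding dimension, used by any entity type that does
            # not override it.
            dimension=400,
            entities={
                "user": {"num_partitions": 1},                    # uses the default (400)
                "word": {"num_partitions": 1, "dimension": 100},  # per-entity override
            },
            relations=[
                {"name": "follows", "lhs": "user", "rhs": "user", "operator": "none"},
                {"name": "cooccurs", "lhs": "word", "rhs": "word", "operator": "none"},
            ],
            global_emb=True,
            num_epochs=10,
        )

With a config like this, embeddings (and, since global_emb is enabled, the
per-entity global embeddings) for "user" would be initialized with 400
dimensions and those for "word" with 100, following the
config.entities[entity].dimension or config.dimension fallback introduced in
this patch. The relations here keep the same entity type, and hence the same
dimension, on both sides, since nothing in the patch adapts the comparator to
differing lhs/rhs dimensions.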