Allow arbitrary dimensions for pbg
Summary:
This is a fairly minimal patch to allow arbitrary per-entity dimensions in pbg.

Note that it does not carry over to filament, i.e. filament still requires everything to have the same dimension.

It's possible that I've missed places that need verification that things are the correct size, and it's not clear to me how well these changes will play with GPU filament.

Reviewed By: lw

Differential Revision: D20316816

fbshipit-source-id: 77de40b86a68bb36441ecfea73677d5baf6f8d7a
Erik Brinkman authored and facebook-github-bot committed Mar 11, 2020
1 parent 332a31a commit 83ff522
Showing 3 changed files with 29 additions and 16 deletions.
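For orientation, here is a minimal sketch of a config that could use the new per-entity `dimension` override introduced by the diffs below. The paths, entity and relation names, and values are illustrative only, and each relation keeps both of its sides at the same dimension, since the patch does not add cross-dimension checks for operators or comparators.

```python
# Illustrative PBG-style config (not part of this commit): "item" overrides
# the default dimension, while "user" inherits it.

def get_torchbiggraph_config():
    return dict(
        entity_path="data/example",
        edge_paths=["data/example/edges"],
        checkpoint_path="model/example",
        entities={
            "user": {"num_partitions": 1},                    # uses `dimension` below
            "item": {"num_partitions": 1, "dimension": 200},  # per-entity override
        },
        relations=[
            # Both sides of each relation share a dimension in this sketch.
            {"name": "follows", "lhs": "user", "rhs": "user", "operator": "none"},
            {"name": "similar_to", "lhs": "item", "rhs": "item", "operator": "none"},
        ],
        dimension=400,   # global default, used where no override is given
        init_scale=0.001,
        num_epochs=10,
    )
```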
torchbiggraph/config.py (5 additions, 0 deletions)

@@ -69,6 +69,11 @@ class EntitySchema(Schema):
         metadata={'help': "Whether the entities of this type are represented "
                           "as sets of features."},
     )
+    dimension: Optional[int] = attr.ib(
+        default=None,
+        validator=optional(positive),
+        metadata={'help': "Override the default dimension for this entity."}
+    )
 
 
 @schema
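Because the new field defaults to None, call sites pick the per-entity value when present and otherwise fall back to `config.dimension`. A tiny sketch of that fallback (the helper name `effective_dimension` is ours, not part of the patch):

```python
from typing import Optional


def effective_dimension(entity_dimension: Optional[int], default_dimension: int) -> int:
    # Per-entity override wins; otherwise use the global default. This mirrors
    # the `... .dimension or config.dimension` pattern in the diffs below.
    # `or` would also treat 0 as unset, which is harmless here because the
    # validator only accepts positive values.
    return entity_dimension or default_dimension


assert effective_dimension(None, 400) == 400
assert effective_dimension(200, 400) == 200
```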
torchbiggraph/model.py (20 additions, 14 deletions)
@@ -767,7 +767,7 @@ class MultiRelationEmbedder(nn.Module):
 
     def __init__(
         self,
-        dim: int,
+        default_dim: int,
         relations: List[RelationSchema],
         entities: Dict[str, EntitySchema],
         num_batch_negs: int,
@@ -783,8 +783,6 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.dim: int = dim
-
         self.relations: List[RelationSchema] = relations
         self.entities: Dict[str, EntitySchema] = entities
         self.num_dynamic_rels: int = num_dynamic_rels
@@ -806,10 +804,11 @@ def __init__(
         self.rhs_embs: nn.ParameterDict = nn.ModuleDict()
 
         if global_emb:
-            self.global_embs: Optional[nn.ParameterDict] = nn.ParameterDict()
-            for entity in entities.keys():
-                self.global_embs[self.EMB_PREFIX + entity] = \
-                    nn.Parameter(torch.zeros((dim,)))
+            global_embs = nn.ParameterDict()
+            for entity, entity_schema in entities.items():
+                global_embs[self.EMB_PREFIX + entity] = \
+                    nn.Parameter(torch.zeros((entity_schema.dimension or default_dim,)))
+            self.global_embs = global_embs
         else:
             self.global_embs: Optional[nn.ParameterDict] = None
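With this change the global-embedding parameters no longer need to share one length. A self-contained sketch of the resulting shapes (the entity names, sizes, and dict keys here are made up for illustration):

```python
import torch
from torch import nn

# Hypothetical entity types: "item" overrides a default dimension of 400.
dims = {"user": 400, "item": 200}

global_embs = nn.ParameterDict()
for entity, dim in dims.items():
    # One zero-initialized global embedding per entity type, sized per entity.
    global_embs[entity] = nn.Parameter(torch.zeros((dim,)))

print(global_embs["user"].shape)  # torch.Size([400])
print(global_embs["item"].shape)  # torch.Size([200])
```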

@@ -1126,12 +1125,17 @@ def forward_direction_agnostic(
         dst_pos = self.adjust_embs(dst_pos, rel, dst_entity_type, dst_operator)
 
         num_chunks = ceil_of_ratio(num_pos, chunk_size)
+        src_dim = src_pos.size(-1)
+        dst_dim = dst_pos.size(-1)
         if num_pos < num_chunks * chunk_size:
-            padding = torch.zeros(()).expand((num_chunks * chunk_size - num_pos, self.dim))
-            src_pos = torch.cat((src_pos, padding), dim=0)
-            dst_pos = torch.cat((dst_pos, padding), dim=0)
-        src_pos = src_pos.view((num_chunks, chunk_size, self.dim))
-        dst_pos = dst_pos.view((num_chunks, chunk_size, self.dim))
+            src_padding = torch.zeros(()).expand(
+                (num_chunks * chunk_size - num_pos, src_dim))
+            src_pos = torch.cat((src_pos, src_padding), dim=0)
+            dst_padding = torch.zeros(()).expand(
+                (num_chunks * chunk_size - num_pos, dst_dim))
+            dst_pos = torch.cat((dst_pos, dst_padding), dim=0)
+        src_pos = src_pos.view((num_chunks, chunk_size, src_dim))
+        dst_pos = dst_pos.view((num_chunks, chunk_size, dst_dim))
 
         src_neg, src_ignore_mask = self.prepare_negatives(
             src, src_pos, src_module, src_negative_sampling_method,
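The hunk above pads the positive embeddings up to a whole number of chunks and reshapes them to (num_chunks, chunk_size, dim), now tracking the source and destination dimensions separately instead of a single self.dim. A standalone sketch of that logic (chunk_pad is a hypothetical helper, not a function in the codebase):

```python
import torch


def chunk_pad(pos: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Pad a (num_pos, dim) tensor with zero rows so it splits evenly into
    chunks of chunk_size, then reshape it to (num_chunks, chunk_size, dim)."""
    num_pos, dim = pos.shape
    num_chunks = -(-num_pos // chunk_size)  # ceiling division, like ceil_of_ratio
    if num_pos < num_chunks * chunk_size:
        # Zero rows broadcast from a scalar, as in the diff above.
        padding = torch.zeros(()).expand((num_chunks * chunk_size - num_pos, dim))
        pos = torch.cat((pos, padding), dim=0)
    return pos.view((num_chunks, chunk_size, dim))


# src and dst keep their own dimensions; only the chunk layout is shared.
src = chunk_pad(torch.randn(10, 400), chunk_size=4)  # shape (3, 4, 400)
dst = chunk_pad(torch.randn(10, 200), chunk_size=4)  # shape (3, 4, 200)
```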
@@ -1189,9 +1193,11 @@ def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
     rhs_operators: List[Optional[Union[AbstractOperator, AbstractDynamicOperator]]] = []
     for r in config.relations:
         lhs_operators.append(
-            instantiate_operator(r.operator, Side.LHS, num_dynamic_rels, config.dimension))
+            instantiate_operator(r.operator, Side.LHS, num_dynamic_rels,
+                                 config.entities[r.lhs].dimension or config.dimension))
         rhs_operators.append(
-            instantiate_operator(r.operator, Side.RHS, num_dynamic_rels, config.dimension))
+            instantiate_operator(r.operator, Side.RHS, num_dynamic_rels,
+                                 config.entities[r.rhs].dimension or config.dimension))
 
     comparator_class = COMPARATORS.get_class(config.comparator)
     comparator = comparator_class()
torchbiggraph/train.py (4 additions, 2 deletions)

@@ -506,8 +506,10 @@ def load_embeddings(
     if embs is None and loadpath_manager is not None:
         embs, optim_state = loadpath_manager.maybe_read(entity, part)
     if embs is None:
-        embs, optim_state = init_embs(entity, entity_counts[entity][part],
-                                      config.dimension, config.init_scale)
+        embs, optim_state = init_embs(
+            entity, entity_counts[entity][part],
+            config.entities[entity].dimension or config.dimension,
+            config.init_scale)
     assert embs.is_shared()
     embs = torch.nn.Parameter(embs)
     optimizer = make_optimizer([embs], True)
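As in make_model, the embedding table for each entity and partition is now initialized at that entity's own dimension. A hedged sketch of the shapes involved; the scaled-Gaussian initialization is an assumption about init_embs internals, not taken from this diff:

```python
from typing import Optional

import torch


def init_entity_embs(count: int, entity_dim: Optional[int],
                     default_dim: int, init_scale: float) -> torch.Tensor:
    # entity_dim is the (possibly None) per-entity override from the config.
    dim = entity_dim or default_dim
    # Assumed initialization: scaled Gaussian; the real init_embs may differ.
    return torch.randn(count, dim) * init_scale


embs = init_entity_embs(count=100_000, entity_dim=200, default_dim=400, init_scale=0.001)
print(embs.shape)  # torch.Size([100000, 200])
```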
