Skip to content

Commit

Permalink
Half-precision just for MM
Browse files · Browse the repository at this point in the history
Summary: Use half-precision in a limited way; this should only improve compute time in the high-negative-count regime.

Reviewed By: lw

Differential Revision: D20464627

fbshipit-source-id: 0be6377c99ca213928ae39de24a4f0880840a58d
  • Loading branch information
adamlerer authored and facebook-github-bot committed Mar 18, 2020
1 parent c7dea47 commit bbf1e7a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
4 changes: 4 additions & 0 deletions torchbiggraph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ class ConfigSchema(Schema):
"this to a value around 16 typically increases "
"communication bandwidth."},
)
half_precision: bool = attr.ib(
default=False,
metadata={'help': "Use half-precision training (GPU ONLY)"},
)

# Additional global validation.

Expand Down
17 changes: 13 additions & 4 deletions torchbiggraph/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def forward(
match_shape(rhs_neg, num_chunks, -1, dim)

# Equivalent to (but faster than) torch.einsum('cid,cid->ci', ...).
pos_scores = (lhs_pos * rhs_pos).sum(-1)
pos_scores = (lhs_pos.float() * rhs_pos.float()).sum(-1)
# Equivalent to (but faster than) torch.einsum('cid,cjd->cij', ...).
lhs_neg_scores = torch.bmm(rhs_pos, lhs_neg.transpose(-1, -2))
rhs_neg_scores = torch.bmm(lhs_pos, rhs_neg.transpose(-1, -2))
Expand Down Expand Up @@ -588,7 +588,7 @@ def forward(
match_shape(rhs_neg, num_chunks, -1, dim)

# Equivalent to (but faster than) torch.einsum('cid,cid->ci', ...).
pos_scores = (lhs_pos * rhs_pos).sum(-1)
pos_scores = (lhs_pos.float() * rhs_pos.float()).sum(-1)
# Equivalent to (but faster than) torch.einsum('cid,cjd->cij', ...).
lhs_neg_scores = torch.bmm(rhs_pos, lhs_neg.transpose(-1, -2))
rhs_neg_scores = torch.bmm(lhs_pos, rhs_neg.transpose(-1, -2))
Expand Down Expand Up @@ -654,7 +654,7 @@ def forward(
match_shape(rhs_neg, num_chunks, -1, dim)

# Smaller distances are higher scores, so take their negatives.
pos_scores = (lhs_pos - rhs_pos).pow_(2).sum(dim=-1).clamp_min_(1e-30).sqrt_().neg()
pos_scores = (lhs_pos.float() - rhs_pos.float()).pow_(2).sum(dim=-1).clamp_min_(1e-30).sqrt_().neg()
lhs_neg_scores = batched_all_pairs_l2_dist(rhs_pos, lhs_neg).neg()
rhs_neg_scores = batched_all_pairs_l2_dist(lhs_pos, rhs_neg).neg()

Expand Down Expand Up @@ -683,7 +683,7 @@ def forward(
match_shape(rhs_neg, num_chunks, -1, dim)

# Smaller distances are higher scores, so take their negatives.
pos_scores = (lhs_pos - rhs_pos).pow_(2).sum(dim=-1).neg()
pos_scores = (lhs_pos.float() - rhs_pos.float()).pow_(2).sum(dim=-1).neg()
lhs_neg_scores = batched_all_pairs_squared_l2_dist(rhs_pos, lhs_neg).neg()
rhs_neg_scores = batched_all_pairs_squared_l2_dist(lhs_pos, rhs_neg).neg()

Expand Down Expand Up @@ -789,6 +789,7 @@ def __init__(
global_emb: bool = False,
max_norm: Optional[float] = None,
num_dynamic_rels: int = 0,
half_precision: bool = False,
) -> None:
super().__init__()

Expand Down Expand Up @@ -822,6 +823,7 @@ def __init__(
self.global_embs: Optional[nn.ParameterDict] = None

self.max_norm: Optional[float] = max_norm
self.half_precision = half_precision

def set_embeddings(self, entity: str, side: Side, weights: nn.Parameter) -> None:
if self.entities[entity].featurized:
Expand Down Expand Up @@ -874,6 +876,8 @@ def adjust_embs(
# 3. Prepare for the comparator.
embs = self.comparator.prepare(embs)

if self.half_precision:
embs = embs.half()
return embs

def prepare_negatives(
Expand Down Expand Up @@ -1164,6 +1168,10 @@ def forward_direction_agnostic(
pos_scores, src_neg_scores, dst_neg_scores = \
self.comparator(src_pos, dst_pos, src_neg, dst_neg)

pos_scores = pos_scores.float()
src_neg_scores = src_neg_scores.float()
dst_neg_scores = dst_neg_scores.float()

# The masks tell us which negative scores (i.e., scores for non-existing
# edges) must be ignored because they come from pairs we don't actually
# intend to compare (say, positive pairs or interactions with padding).
Expand Down Expand Up @@ -1236,6 +1244,7 @@ def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
global_emb=config.global_emb,
max_norm=config.max_norm,
num_dynamic_rels=num_dynamic_rels,
half_precision=config.half_precision,
)


Expand Down

0 comments on commit bbf1e7a

Please sign in to comment.