# Fix regularizers on GPU, fix test (facebookresearch#182)
Summary:
Fix regularizers on GPU.

## Types of changes

- [ ] Docs change / refactoring / dependency upgrade
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

## Motivation and Context / Related issue

facebookresearch#178

## How Has This Been Tested (if it applies)

Fixed the integration tests to catch the error. The GPU test was already running with regularization, but the relation had no operator, so the buggy code path was never exercised and the error went undetected.

## Checklist

- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: facebookresearch#182

Reviewed By: lw

Differential Revision: D25930089

Pulled By: adamlerer

fbshipit-source-id: 76e4138b6a84021f87bb0afc5d3303ff0cabe9cf
adamlerer authored and facebook-github-bot committed Jan 19, 2021
1 parent dfd307b commit 4571dee
Showing 2 changed files with 6 additions and 2 deletions.
`test/test_functional.py` — 3 additions, 1 deletion:

```diff
@@ -506,7 +506,9 @@ def test_gpu_1partition(self):
 
     def _test_gpu(self, do_half_precision=False, num_partitions=2):
         entity_name = "e"
-        relation_config = RelationSchema(name="r", lhs=entity_name, rhs=entity_name)
+        relation_config = RelationSchema(
+            name="r", lhs=entity_name, rhs=entity_name, operator="complex_diagonal"
+        )
         base_config = ConfigSchema(
             dimension=16,
             batch_size=1024,
```
`torchbiggraph/regularizers.py` — 3 additions, 1 deletion:

```diff
@@ -72,7 +72,8 @@ def forward_dynamic(
         total = 0
         operator_params = operator.get_operator_params_for_reg(rel_idxs)
         if operator_params is not None:
-            total += torch.sum(operator_params ** 3).to(src_pos.device)
+            operator_params = operator_params.to(src_pos.device)
+            total += torch.sum(operator_params ** 3)
         for x in (src_pos, dst_pos):
             total += torch.sum(operator.prepare_embs_for_reg(x) ** 3)
         total *= self.weight
@@ -87,6 +88,7 @@ def forward(
         total = 0
         operator_params = operator.get_operator_params_for_reg()
         if operator_params is not None:
+            operator_params = operator_params.to(src_pos.device)
             batch_size = len(src_pos)
             total += torch.sum(operator_params ** 3) * batch_size
         for x in (src_pos, dst_pos):
```
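The core of the fix is moving the operator parameter tensor onto the embeddings' device *before* it enters the regularization sum, so every term accumulated into `total` lives on one device. The toy function below sketches that pattern in isolation; the names and structure are illustrative stand-ins, not the actual PBG regularizer API, and it runs on CPU so the demo works anywhere:

```python
import torch

def cube_regularizer(operator_params, src_pos, dst_pos, weight=0.1):
    """Toy analogue of the fixed regularizer: accumulate a cubed penalty
    for operator params and both embedding sides on the embeddings' device.
    Hypothetical names, not the real torchbiggraph API."""
    total = torch.zeros((), device=src_pos.device)
    if operator_params is not None:
        # The fix: move the params to the embeddings' device *before*
        # summing, so `total` never mixes tensors from two devices.
        operator_params = operator_params.to(src_pos.device)
        total = total + torch.sum(operator_params ** 3)
    for x in (src_pos, dst_pos):
        total = total + torch.sum(x ** 3)
    return weight * total

# CPU demo: params contribute 1**3 + 2**3 = 9, each embedding side 2.
params = torch.tensor([1.0, 2.0])
src = torch.ones(2)
dst = torch.ones(2)
print(cube_regularizer(params, src, dst, weight=1.0))  # tensor(13.)
```

Passing `operator_params=None` mirrors the no-operator case from the test gap above: the guarded branch is skipped entirely, which is why the pre-fix GPU test never hit the device mismatch.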
