Skip to content

Commit

Permalink
Move matrixtree test from separate main into tests. (OpenNMT#1242)
Browse files Browse the repository at this point in the history
  • Loading branch information
flauted authored and vince62s committed Feb 1, 2019
1 parent 6f11f29 commit 7732f27
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
11 changes: 1 addition & 10 deletions onmt/modules/structured_attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch.nn as nn
import torch
import torch.cuda
from onmt.utils.logging import init_logger


class MatrixTree(nn.Module):
Expand All @@ -22,7 +21,7 @@ def forward(self, input):
output = input.clone()
for b in range(input.size(0)):
lap = laplacian[b].masked_fill(
torch.eye(input.size(1)).cuda().ne(0), 0)
torch.eye(input.size(1), device=input.device).ne(0), 0)
lap = -lap + torch.diag(lap.sum(0))
# store roots on diagonal
lap[0] = input[b].diag().exp()
Expand All @@ -39,11 +38,3 @@ def forward(self, input):
inv_laplacian.transpose(0, 1)[0])
output[b] = output[b] + torch.diag(roots_output)
return output


if __name__ == "__main__":
logger = init_logger('StructuredAttention.log')
dtree = MatrixTree()
q = torch.rand(1, 5, 5).cuda()
marg = dtree.forward(q)
logger.info(marg.sum(1))
13 changes: 13 additions & 0 deletions onmt/tests/test_structured_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import unittest
from onmt.modules.structured_attention import MatrixTree

import torch


class TestStructuredAttention(unittest.TestCase):
    """Sanity checks for the MatrixTree structured-attention module."""

    def test_matrix_tree_marg_pdfs_sum_to_1(self):
        """Each marginal pdf returned by MatrixTree should sum to 1.

        MatrixTree computes marginal distributions over tree structures
        via the matrix-tree theorem; summing a proper pdf over its
        support must yield 1 (up to floating-point tolerance).
        """
        tree = MatrixTree()
        # Random (batch=1, n=5, n=5) score matrix as input potentials.
        potentials = torch.rand(1, 5, 5)
        marginals = tree.forward(potentials)
        column_sums = marginals.sum(1)
        self.assertTrue(torch.allclose(column_sums, torch.tensor(1.0)))

0 comments on commit 7732f27

Please sign in to comment.