Skip to content

Commit

Permalink
Introduce EdgeList
Browse files Browse the repository at this point in the history
Summary: Couple lhs, rhs and rel together in what is essentially a dataclass with a couple of helper methods. This makes quite a bunch of interfaces a bit narrower. Also, it will help attach additional fields to edges with most code needed no futher changes.

Reviewed By: adamlerer

Differential Revision: D15109566

fbshipit-source-id: 7f9a42e9d7228560de9776634df7be17944b6d2c
  • Loading branch information
lw authored and facebook-github-bot committed May 1, 2019
1 parent 701ec17 commit 03571af
Show file tree
Hide file tree
Showing 15 changed files with 775 additions and 292 deletions.
2 changes: 2 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ jobs:
python3 tests/batching_tests.py
python3 tests/bucket_scheduling_tests.py
python3 tests/distributed_tests.py
python3 tests/edgelist_tests.py
python3 tests/entitylist_tests.py
python3 tests/fileio_tests.py
python3 tests/functional_tests.py
python3 tests/losses_tests.py
Expand Down
40 changes: 17 additions & 23 deletions examples/filtered_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,16 @@
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Dict, List, Union, Tuple

import torch
from typing import Dict, List, Tuple

from torchbiggraph.config import ConfigSchema
from torchbiggraph.entitylist import EntityList
from torchbiggraph.edgelist import EdgeList
from torchbiggraph.eval import RankingEvaluator
from torchbiggraph.fileio import EdgeReader
from torchbiggraph.model import Scores
from torchbiggraph.util import log
from torchbiggraph.stats import Stats
from torchbiggraph.types import Partition, LongTensorType
from torchbiggraph.types import Partition


class FilteredRankingEvaluator(RankingEvaluator):
Expand Down Expand Up @@ -49,14 +47,14 @@ def __init__(self, config: ConfigSchema, filter_paths: List[str]):
log("Building links map from path %s" % path)
e_reader = EdgeReader(path)
# Assume unpartitioned.
lhs, rhs, rel = e_reader.read(Partition(0), Partition(0))
num_edges = lhs.size(0)
for i in range(num_edges):
edges = e_reader.read(Partition(0), Partition(0))
for idx in range(len(edges)):
# Assume non-featurized.
cur_lhs = lhs.to_tensor()[i].item()
cur_rel = rel[i].item()
cur_lhs = int(edges.lhs.to_tensor()[idx])
# Assume dynamic relations.
cur_rel = int(edges.rel[idx])
# Assume non-featurized.
cur_rhs = rhs.to_tensor()[i].item()
cur_rhs = int(edges.rhs.to_tensor()[idx])

self.lhs_map[cur_lhs, cur_rel].append(cur_rhs)
self.rhs_map[cur_rhs, cur_rel].append(cur_lhs)
Expand All @@ -66,21 +64,17 @@ def __init__(self, config: ConfigSchema, filter_paths: List[str]):
def eval(
self,
scores: Scores,
batch_lhs: EntityList,
batch_rhs: EntityList,
batch_rel: Union[int, LongTensorType],
batch_edges: EdgeList,
) -> Stats:
# Assume dynamic relations.
assert isinstance(batch_rel, torch.LongTensor)

_, _, lhs_neg_scores, rhs_neg_scores = scores
b = batch_lhs.size(0)
for idx in range(b):

for idx in range(len(batch_edges)):
# Assume non-featurized.
cur_lhs = batch_lhs.to_tensor()[idx].item()
cur_rel = batch_rel[idx].item()
cur_lhs = int(batch_edges.lhs.to_tensor()[idx])
# Assume dynamic relations.
cur_rel = int(batch_edges.rel[idx])
# Assume non-featurized.
cur_rhs = batch_rhs.to_tensor()[idx].item()
cur_rhs = int(batch_edges.rhs.to_tensor()[idx])

rhs_edges_filtered = self.lhs_map[cur_lhs, cur_rel]
lhs_edges_filtered = self.rhs_map[cur_rhs, cur_rel]
Expand All @@ -93,4 +87,4 @@ def eval(
lhs_neg_scores[idx][lhs_edges_filtered] = -1e9
rhs_neg_scores[idx][rhs_edges_filtered] = -1e9

return super().eval(scores, batch_lhs, batch_rhs, batch_rel)
return super().eval(scores, batch_edges)
196 changes: 98 additions & 98 deletions tests/batching_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch

from torchbiggraph.edgelist import EdgeList
from torchbiggraph.entitylist import EntityList
from torchbiggraph.batching import (
group_by_relation_type,
Expand All @@ -25,63 +26,69 @@ class TestGroupByRelationType(TestCase):
def test_basic(self):
self.assertEqual(
group_by_relation_type(
torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long),
EntityList.from_tensor(torch.tensor(
[93, 24, 13, 31, 70, 66, 77, 38, 5, 5], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[90, 75, 9, 25, 23, 31, 49, 64, 42, 50], dtype=torch.long)),
EdgeList(
EntityList.from_tensor(torch.tensor(
[93, 24, 13, 31, 70, 66, 77, 38, 5, 5], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[90, 75, 9, 25, 23, 31, 49, 64, 42, 50], dtype=torch.long)),
torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long),
),
),
[
(
EdgeList(
EntityList.from_tensor(
torch.tensor([24, 13, 77, 38], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([75, 9, 49, 64], dtype=torch.long)),
0,
torch.tensor(0, dtype=torch.long),
),
(
EdgeList(
EntityList.from_tensor(
torch.tensor([93, 31], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([90, 25], dtype=torch.long)),
1,
torch.tensor(1, dtype=torch.long),
),
(
EdgeList(
EntityList.from_tensor(
torch.tensor([70, 66, 5, 5], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([23, 31, 42, 50], dtype=torch.long)),
2,
torch.tensor(2, dtype=torch.long),
),
],
)

def test_constant(self):
self.assertEqual(
group_by_relation_type(
torch.tensor([3, 3, 3, 3], dtype=torch.long),
EntityList.from_tensor(torch.tensor(
[93, 24, 13, 31], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[90, 75, 9, 25], dtype=torch.long)),
EdgeList(
EntityList.from_tensor(torch.tensor(
[93, 24, 13, 31], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[90, 75, 9, 25], dtype=torch.long)),
torch.tensor([3, 3, 3, 3], dtype=torch.long),
),
),
[
(
EdgeList(
EntityList.from_tensor(torch.tensor(
[93, 24, 13, 31], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[90, 75, 9, 25], dtype=torch.long)),
3,
torch.tensor(3, dtype=torch.long),
),
],
)

def test_empty(self):
self.assertEqual(
group_by_relation_type(
torch.empty((0,), dtype=torch.long),
EntityList.empty(),
EntityList.empty(),
EdgeList(
EntityList.empty(),
EntityList.empty(),
torch.empty((0,), dtype=torch.long),
),
),
[],
)
Expand All @@ -90,92 +97,85 @@ def test_empty(self):
class TestBatchEdgesMixRelationTypes(TestCase):

def test_basic(self):
actual_batches = batch_edges_mix_relation_types(
EntityList.from_tensor(torch.tensor(
[93, 24, 13, 31, 70, 66, 77, 38, 5, 5], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[90, 75, 9, 25, 23, 31, 49, 64, 42, 50], dtype=torch.long)),
torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long),
batch_size=4,
self.assertEqual(
list(batch_edges_mix_relation_types(
EdgeList(
EntityList.from_tensor(torch.tensor(
[93, 24, 13, 31, 70, 66, 77, 38, 5, 5], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[90, 75, 9, 25, 23, 31, 49, 64, 42, 50], dtype=torch.long)),
torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long),
),
batch_size=4,
)),
[
EdgeList(
EntityList.from_tensor(
torch.tensor([93, 24, 13, 31], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([90, 75, 9, 25], dtype=torch.long)),
torch.tensor([1, 0, 0, 1], dtype=torch.long),
),
EdgeList(
EntityList.from_tensor(
torch.tensor([70, 66, 77, 38], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([23, 31, 49, 64], dtype=torch.long)),
torch.tensor([2, 2, 0, 0], dtype=torch.long),
),
EdgeList(
EntityList.from_tensor(
torch.tensor([5, 5], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([42, 50], dtype=torch.long)),
torch.tensor([2, 2], dtype=torch.long),
),
],
)
expected_batches = [
(
EntityList.from_tensor(
torch.tensor([93, 24, 13, 31], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([90, 75, 9, 25], dtype=torch.long)),
torch.tensor([1, 0, 0, 1], dtype=torch.long),
),
(
EntityList.from_tensor(
torch.tensor([70, 66, 77, 38], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([23, 31, 49, 64], dtype=torch.long)),
torch.tensor([2, 2, 0, 0], dtype=torch.long),
),
(
EntityList.from_tensor(
torch.tensor([5, 5], dtype=torch.long)),
EntityList.from_tensor(
torch.tensor([42, 50], dtype=torch.long)),
torch.tensor([2, 2], dtype=torch.long),
),
]
# We can't use assertEqual because == between tensors doesn't work.
for actual_batch, expected_batch \
in zip_longest(actual_batches, expected_batches):
a_lhs, a_rhs, a_rel = actual_batch
e_lhs, e_rhs, e_rel = expected_batch
self.assertEqual(a_lhs, e_lhs)
self.assertEqual(a_rhs, e_rhs)
self.assertTrue(torch.equal(a_rel, e_rel), "%s != %s" % (a_rel, e_rel))


class TestBatchEdgesGroupByType(TestCase):

def test_basic(self):
lhs = EntityList.from_tensor(torch.tensor(
[93, 24, 13, 31, 70, 66, 77, 38, 5, 5], dtype=torch.long))
rhs = EntityList.from_tensor(torch.tensor(
[90, 75, 9, 25, 23, 31, 49, 64, 42, 50], dtype=torch.long))
rel = torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long)
lhs_by_type = defaultdict(list)
rhs_by_type = defaultdict(list)
for batch_lhs, batch_rhs, rel_type in batch_edges_group_by_relation_type(
lhs, rhs, rel, batch_size=3
):
self.assertIsInstance(batch_lhs, EntityList)
self.assertLessEqual(batch_lhs.size(0), 3)
lhs_by_type[rel_type].append(batch_lhs)
self.assertIsInstance(batch_rhs, EntityList)
self.assertLessEqual(batch_rhs.size(0), 3)
rhs_by_type[rel_type].append(batch_rhs)
self.assertCountEqual(lhs_by_type.keys(), [0, 1, 2])
self.assertCountEqual(rhs_by_type.keys(), [0, 1, 2])
self.assertEqual(
EntityList.cat(lhs_by_type[0]),
EntityList.from_tensor(torch.tensor(
[24, 13, 77, 38], dtype=torch.long)))
self.assertEqual(
EntityList.cat(rhs_by_type[0]),
edges = EdgeList(
EntityList.from_tensor(torch.tensor(
[75, 9, 49, 64], dtype=torch.long)))
self.assertEqual(
EntityList.cat(lhs_by_type[1]),
EntityList.from_tensor(torch.tensor(
[93, 31], dtype=torch.long)))
self.assertEqual(
EntityList.cat(rhs_by_type[1]),
EntityList.from_tensor(torch.tensor(
[90, 25], dtype=torch.long)))
self.assertEqual(
EntityList.cat(lhs_by_type[2]),
[93, 24, 13, 31, 70, 66, 77, 38, 5, 5], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[70, 66, 5, 5], dtype=torch.long)))
[90, 75, 9, 25, 23, 31, 49, 64, 42, 50], dtype=torch.long)),
torch.tensor([1, 0, 0, 1, 2, 2, 0, 0, 2, 2], dtype=torch.long),
)
edges_by_type = defaultdict(list)
for batch_edges in batch_edges_group_by_relation_type(edges, batch_size=3):
self.assertIsInstance(batch_edges, EdgeList)
self.assertLessEqual(len(batch_edges), 3)
self.assertTrue(batch_edges.has_scalar_relation_type())
edges_by_type[batch_edges.get_relation_type_as_scalar()].append(batch_edges)
self.assertEqual(
EntityList.cat(rhs_by_type[2]),
EntityList.from_tensor(torch.tensor(
[23, 31, 42, 50], dtype=torch.long)))
{k: EdgeList.cat(v) for k, v in edges_by_type.items()},
{
0: EdgeList(
EntityList.from_tensor(torch.tensor(
[24, 13, 77, 38], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[75, 9, 49, 64], dtype=torch.long)),
torch.tensor(0, dtype=torch.long),
),
1: EdgeList(
EntityList.from_tensor(torch.tensor(
[93, 31], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[90, 25], dtype=torch.long)),
torch.tensor(1, dtype=torch.long),
),
2: EdgeList(
EntityList.from_tensor(torch.tensor(
[70, 66, 5, 5], dtype=torch.long)),
EntityList.from_tensor(torch.tensor(
[23, 31, 42, 50], dtype=torch.long)),
torch.tensor(2, dtype=torch.long),
),
},
)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 03571af

Please sign in to comment.