#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE.txt file in the root directory of this source tree.

import time
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List, Optional

import torch

from torchbiggraph.edgelist import EdgeList
from torchbiggraph.model import MultiRelationEmbedder
from torchbiggraph.stats import Stats
from torchbiggraph.types import LongTensorType


def group_by_relation_type(edges: EdgeList) -> List[EdgeList]:
    """Split the edge list into groups that each have a single relation type."""
    if len(edges) == 0:
        return []
    if edges.has_scalar_relation_type():
        return [edges]
    # FIXME Is PyTorch's sort stable? Won't this risk messing up the random shuffle?
    sorted_rel, order = edges.rel.sort()
    # A nonzero delta marks a boundary between two different relation types.
    delta = sorted_rel[1:] - sorted_rel[:-1]
    cutpoints = (delta.nonzero().flatten() + 1).tolist()
    result: List[EdgeList] = []
    for start, end in zip([0] + cutpoints, cutpoints + [len(edges)]):
        rel_type = sorted_rel[start]
        edges_for_rel_type = edges[order[start:end]]
        result.append(EdgeList(edges_for_rel_type.lhs,
                               edges_for_rel_type.rhs,
                               rel_type))
    return result
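
# Illustrative sketch (not part of the original module) of how the cutpoint
# arithmetic above turns a sorted relation-type tensor into contiguous
# slices; the values are made up:
#
#     sorted_rel = torch.tensor([0, 0, 1, 1, 1, 3])
#     delta = sorted_rel[1:] - sorted_rel[:-1]              # tensor([0, 1, 0, 0, 2])
#     cutpoints = (delta.nonzero().flatten() + 1).tolist()  # [2, 5]
#     # zip([0] + cutpoints, cutpoints + [6]) -> (0, 2), (2, 5), (5, 6),
#     # i.e. one contiguous slice of `order` per relation type.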


def batch_edges_mix_relation_types(
    edges: EdgeList,
    *,
    batch_size: int,
) -> Iterable[EdgeList]:
    """Split the edges into batches that may contain multiple relation types

    The output preserves the input's order. Batches are all of the same size,
    except possibly the last one.
    """
    for offset in range(0, len(edges), batch_size):
        yield edges[offset:offset + batch_size]
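
# Sketch of the behaviour above (assuming only what the function itself uses,
# namely that an EdgeList supports len() and slicing): with len(edges) == 10
# and batch_size == 4 the offsets are 0, 4, 8, so the yielded batches are
# edges[0:4], edges[4:8], edges[8:10].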


def batch_edges_group_by_relation_type(
    edges: EdgeList,
    *,
    batch_size: int,
) -> Iterable[EdgeList]:
    """Split the edges into batches that each contain a single relation type

    Batches are all of the same size, except possibly the last one for each
    relation type.
    """
    edge_groups = group_by_relation_type(edges)
    num_edges_left_per_group = torch.tensor(
        [len(edges) for edges in edge_groups], dtype=torch.long)
    while num_edges_left_per_group.sum() > 0:
        # Pick a group with probability proportional to its remaining edges.
        idx = int(torch.multinomial(num_edges_left_per_group.float(), 1))
        edge_group = edge_groups[idx]
        offset = len(edge_group) - int(num_edges_left_per_group[idx])
        sub_edges = edge_group[offset:offset + batch_size]
        yield sub_edges
        num_edges_left_per_group[idx] -= len(sub_edges)
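
# Illustrative sketch of the sampling step above (values made up): with the
# remaining counts as weights, torch.multinomial picks each group with
# probability proportional to how many of its edges are left, so every
# relation type is drained at the same relative rate:
#
#     num_edges_left_per_group = torch.tensor([900, 90, 10])
#     idx = int(torch.multinomial(num_edges_left_per_group.float(), 1))
#     # idx == 0 with probability 0.90, 1 with probability 0.09, 2 with 0.01.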


def call(f: Callable[[], Stats]) -> Stats:
    """Helper to be able to do pool.map(call, [partial(f, foo=42)])

    Using pool.starmap(f, [(42,)]) is shorter, but it doesn't support keyword
    arguments. It appears going through partial is the only way to do that.
    """
    return f()
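
# Usage sketch for `call` (hypothetical names: `pool` is any
# multiprocessing.pool.Pool-like object, `f` any function returning Stats):
#
#     from functools import partial
#     stats_list = pool.map(call, [partial(f, foo=42), partial(f, foo=43)])
#
# Each partial binds its keyword arguments up front, so the pool only ever
# has to deal with zero-argument callables.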


def process_in_batches(
    batch_size: int,
    model: MultiRelationEmbedder,
    batch_processor: "AbstractBatchProcessor",
    edges: EdgeList,
    indices: Optional[LongTensorType] = None,
    delay: float = 0.0,
) -> Stats:
    """Split lhs, rhs and rel into batches, process them and sum the stats

    If indices is not None, only operate on x[indices] for x = lhs, rhs and rel.
    If delay is positive, wait for that many seconds before starting.
    """
    if indices is not None:
        edges = edges[indices]
    time.sleep(delay)
    # FIXME: it's not really safe to do partial batches if num_batch_negs != 0
    # because partial batches will produce incorrect results, and if the
    # dataset per thread is very small then every batch may be partial. I don't
    # know of a perfect solution for this that doesn't introduce other biases...
    all_stats = []
    # Dynamic-relation models can score batches that mix relation types;
    # other models need one relation type per batch.
    if model.num_dynamic_rels > 0:
        batcher = batch_edges_mix_relation_types
    else:
        batcher = batch_edges_group_by_relation_type
    for batch_edges in batcher(edges, batch_size=batch_size):
        all_stats.append(batch_processor.process_one_batch(model, batch_edges))
    stats = Stats.sum(all_stats)
    if indices is not None:
        assert stats.count == indices.size(0)
    return stats
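
# Usage sketch (hypothetical setup; a real model and processor come from the
# rest of torchbiggraph):
#
#     stats = process_in_batches(
#         batch_size=1000,
#         model=model,              # a MultiRelationEmbedder
#         batch_processor=trainer,  # an AbstractBatchProcessor subclass
#         edges=edges,              # an EdgeList
#     )
#     print(stats.count)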


class AbstractBatchProcessor(ABC):
    """Process a single batch of edges and report its Stats.

    process_in_batches calls process_one_batch once per batch and sums the
    resulting Stats; concrete subclasses provide the actual per-batch work.
    """

    @abstractmethod
    def process_one_batch(
        self,
        model: MultiRelationEmbedder,
        batch_edges: EdgeList,
    ) -> Stats:
        pass
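
# Minimal sketch of a concrete processor (not part of the original module;
# `EdgeCountingBatchProcessor` is a hypothetical name, and it assumes Stats
# can be constructed from a count alone):
#
#     class EdgeCountingBatchProcessor(AbstractBatchProcessor):
#         """Toy processor that only counts edges and ignores the model."""
#
#         def process_one_batch(
#             self,
#             model: MultiRelationEmbedder,
#             batch_edges: EdgeList,
#         ) -> Stats:
#             return Stats(count=len(batch_edges))
#
# A real processor (e.g. a training step) would run the model on batch_edges
# and report its loss and other metrics through the same Stats object.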