forked from wouterkool/attention-learn-to-route
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathboolmask.py
68 lines (52 loc) · 2.53 KB
/
boolmask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn.functional as F
def _pad_mask(mask):
# By taking -size % 8, we get 0 if exactly divisible by 8
# and required padding otherwise (i.e. -1 % 8 = 7 pad)
pad = -mask.size(-1) % 8
if pad != 0:
mask = F.pad(mask, [0, pad])
return mask, mask.size(-1) // 8
def _mask_bool2byte(mask):
    """Pack a 0/1 uint8 mask into bytes, 8 input elements per output byte.

    Bit i of output byte j corresponds to input element 8*j + i
    (little-endian bit order). The last dimension is zero-padded to a
    multiple of 8 and then shrinks by a factor of 8.
    """
    assert mask.dtype == torch.uint8
    # assert (mask <= 1).all()  # Precondition, disabled for efficiency
    mask, d = _pad_mask(mask)
    # Explicit dtype/device replaces the deprecated `out=mask.new()` form,
    # matching the style already used in _mask_byte2long.
    shifts = torch.arange(8, dtype=mask.dtype, device=mask.device)
    return (mask.view(*mask.size()[:-1], d, 8) << shifts).sum(-1, dtype=torch.uint8)
def _mask_byte2long(mask):
    """Pack a byte mask into int64 values, 8 bytes per long (little-endian)."""
    assert mask.dtype == torch.uint8
    mask, d = _pad_mask(mask)
    # Converting to long before summing costs a temporary factor-8 memory
    # overhead; aggregating with a for loop would avoid it.
    shifts = torch.arange(8, dtype=torch.int64, device=mask.device) * 8
    grouped = mask.view(*mask.size()[:-1], d, 8).long()
    return (grouped << shifts).sum(-1)
def mask_bool2long(mask):
    """Compress a 0/1 uint8 mask into int64 values, 64 bits per long."""
    assert mask.dtype == torch.uint8
    as_bytes = _mask_bool2byte(mask)
    return _mask_byte2long(as_bytes)
def _mask_long2byte(mask, n=None):
if n is None:
n = 8 * mask.size(-1)
return (mask[..., None] >> (torch.arange(8, out=mask.new()) * 8))[..., :n].to(torch.uint8).view(*mask.size()[:-1], -1)[..., :n]
def _mask_byte2bool(mask, n=None):
if n is None:
n = 8 * mask.size(-1)
return (mask[..., None] & (mask.new_ones(8) << torch.arange(8, out=mask.new()) * 1)).view(*mask.size()[:-1], -1)[..., :n] > 0
def mask_long2bool(mask, n=None):
    """Expand an int64-packed mask back into a boolean mask of length `n`."""
    assert mask.dtype == torch.int64
    as_bytes = _mask_long2byte(mask)
    return _mask_byte2bool(as_bytes, n=n)
def mask_long_scatter(mask, values, check_unset=True):
    """
    Set one bit per batch entry in an int64-packed mask (64 bits per long).

    Args:
        mask: int64 tensor of packed bits, shape (*batch, n_longs).
        values: tensor of bit indices to set, shape (*batch). An entry of -1
            (or any out-of-range value) sets nothing for that batch entry.
        check_unset: if True, assert the target bit is not already set.

    Returns:
        A new mask with the requested bits set (the input is not modified).

    Note: does not work for setting multiple values at once (like normal scatter)
    """
    assert mask.size()[:-1] == values.size()
    # Explicit dtype/device replaces the deprecated `out=mask.new()` form.
    rng = torch.arange(mask.size(-1), dtype=torch.int64, device=mask.device)
    values_ = values[..., None]  # Need to broadcast up to the mask dim
    # Selects which long of the packed mask holds each value's bit; the row is
    # all-False for out-of-range values (e.g. -1), so nothing gets set.
    where = (values_ >= (rng * 64)) & (values_ < ((rng + 1) * 64))
    # Optional sanity check that the bit is not already set. Debug aid only:
    # asserts are stripped under `python -O`.
    assert not (check_unset and ((mask & (where.long() << (values_ % 64))) > 0).any())
    # Set the bit by shifting a 1 into position; where == 0 shifts a 0 and
    # leaves the mask unchanged. The % 64 keeps the shift amount in range
    # (shifting an int64 by >= 64 is not well-defined).
    return mask | (where.long() << (values_ % 64))