
Commit

Merge pull request wouterkool#25 from wouterkool/v4
Update to Python 3.8, PyTorch 1.7, some bug fixes
wouterkool authored Nov 20, 2020
2 parents ffd5b86 + 5fa0b17 commit 21e5a00
Showing 8 changed files with 82 additions and 84 deletions.
19 changes: 9 additions & 10 deletions nets/attention_model.py
@@ -29,15 +29,14 @@ class AttentionModelFixed(NamedTuple):
     logit_key: torch.Tensor
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):
-            return AttentionModelFixed(
-                node_embeddings=self.node_embeddings[key],
-                context_node_projected=self.context_node_projected[key],
-                glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
-                glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
-                logit_key=self.logit_key[key]
-            )
-        return super(AttentionModelFixed, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)
+        return AttentionModelFixed(
+            node_embeddings=self.node_embeddings[key],
+            context_node_projected=self.context_node_projected[key],
+            glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
+            glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
+            logit_key=self.logit_key[key]
+        )
 
 
 class AttentionModel(nn.Module):
@@ -172,7 +171,7 @@ def propose_expansions(self, beam, fixed, expand_size=None, normalize=False, max
         flat_feas = flat_score > -1e10  # != -math.inf triggers
 
         # Parent is row idx of ind_topk, can be found by enumerating elements and dividing by number of columns
-        flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) / ind_topk.size(-1)
+        flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) // ind_topk.size(-1)
 
         # Filter infeasible
         feas_ind_2d = torch.nonzero(flat_feas)
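The // change above matters because newer PyTorch versions deprecate (and eventually change) integer tensor division with /, so floor division keeps the parent index integral and usable for indexing. A minimal sketch of the idea with toy tensors, not the model's actual variables:

import torch

# Toy sketch: recovering the parent row and the column (action) of a flattened
# top-k index over a beam_size x n_actions score matrix.
scores = torch.tensor([[0.1, 0.9, 0.3],
                       [0.8, 0.2, 0.7]])
flat_idx = scores.view(-1).topk(3).indices   # tensor([1, 3, 5])
parent = flat_idx // scores.size(-1)         # rows:    tensor([0, 1, 1])
action = flat_idx % scores.size(-1)          # columns: tensor([1, 0, 2])
# With plain / this is deprecated or becomes true division (a float tensor) on
# newer PyTorch, which can no longer be used directly as an index.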
15 changes: 11 additions & 4 deletions nets/graph_encoder.py
@@ -19,14 +19,13 @@ def __init__(
             self,
             n_heads,
             input_dim,
-            embed_dim=None,
+            embed_dim,
             val_dim=None,
             key_dim=None
     ):
         super(MultiHeadAttention, self).__init__()
 
         if val_dim is None:
-            assert embed_dim is not None, "Provide either embed_dim or val_dim"
             val_dim = embed_dim // n_heads
         if key_dim is None:
             key_dim = val_dim
@@ -43,8 +42,7 @@ def __init__(
         self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
         self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
 
-        if embed_dim is not None:
-            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim))
+        self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))
 
         self.init_parameters()
 
@@ -109,6 +107,15 @@ def forward(self, q, h=None, mask=None):
             self.W_out.view(-1, self.embed_dim)
         ).view(batch_size, n_query, self.embed_dim)
 
+        # Alternative:
+        # headst = heads.transpose(0, 1)  # swap the dimensions for batch and heads to align it for the matmul
+        # # proj_h = torch.einsum('bhni,hij->bhnj', headst, self.W_out)
+        # projected_heads = torch.matmul(headst, self.W_out)
+        # out = torch.sum(projected_heads, dim=1)  # sum across heads
+
+        # Or:
+        # out = torch.einsum('hbni,hij->bnj', heads, self.W_out)
+
         return out
 
 
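The shape fix for W_out above (val_dim rather than key_dim on its middle dimension) and the commented-out alternatives all describe the same head-combination step. A self-contained sketch of the equivalence, using made-up sizes and random tensors rather than the module's own attributes:

import torch

# heads: (n_heads, batch, n_query, val_dim); W_out: (n_heads, val_dim, embed_dim).
# Both routes project each head's values and sum the heads into one
# (batch, n_query, embed_dim) output.
n_heads, batch, n_query, val_dim, embed_dim = 8, 2, 5, 16, 128
heads = torch.randn(n_heads, batch, n_query, val_dim)
W_out = torch.randn(n_heads, val_dim, embed_dim)

# Fold the heads into the feature dimension and use a single mm, roughly as the
# surrounding forward pass does (only partly visible in this hunk).
out = torch.mm(
    heads.permute(1, 2, 0, 3).contiguous().view(-1, n_heads * val_dim),
    W_out.view(-1, embed_dim)
).view(batch, n_query, embed_dim)

# The einsum from the "Or:" comment: sum over heads (h) and val_dim (i) at once.
out_alt = torch.einsum('hbni,hij->bnj', heads, W_out)

assert torch.allclose(out, out_alt, atol=1e-4)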
19 changes: 9 additions & 10 deletions problems/op/state_op.py
@@ -36,16 +36,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-                cur_total_prize=self.cur_total_prize[key],
-            )
-        return super(StateOP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+            cur_total_prize=self.cur_total_prize[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
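The same assert-instead-of-fallthrough change is applied to StateOP above and to the remaining state classes below. The point is that NamedTuple already defines __getitem__ for positional field access, so letting integer keys fall through to super() silently returned a field rather than a sub-batch of the state. A toy illustration with a hypothetical MiniState, not one of the repository's classes:

from typing import NamedTuple

import torch

class MiniState(NamedTuple):
    ids: torch.Tensor
    lengths: torch.Tensor

    def __getitem__(self, key):
        # Only batch-style indexing is allowed; integer field access would be a bug.
        assert torch.is_tensor(key) or isinstance(key, slice)
        return self._replace(ids=self.ids[key], lengths=self.lengths[key])

state = MiniState(ids=torch.arange(4), lengths=torch.zeros(4))
sub = state[torch.tensor([0, 2])]   # keep rows 0 and 2 of every tensor
print(sub.ids)                      # tensor([0, 2])
# state[0] now raises an AssertionError instead of quietly returning the ids field.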
21 changes: 10 additions & 11 deletions problems/pctsp/state_pctsp.py
@@ -36,17 +36,16 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_total_prize=self.cur_total_prize[key],
-                cur_total_penalty=self.cur_total_penalty[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StatePCTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_total_prize=self.cur_total_prize[key],
+            cur_total_penalty=self.cur_total_penalty[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
19 changes: 9 additions & 10 deletions problems/tsp/state_tsp.py
@@ -28,16 +28,15 @@ def visited(self):
         return mask_long2bool(self.visited_, n=self.loc.size(-2))
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                first_a=self.first_a[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
-            )
-        return super(StateTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            first_a=self.first_a[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
+        )
 
     @staticmethod
     def initialize(loc, visited_dtype=torch.uint8):
19 changes: 9 additions & 10 deletions problems/vrp/state_cvrp.py
@@ -34,16 +34,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateCVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
19 changes: 9 additions & 10 deletions problems/vrp/state_sdvrp.py
@@ -22,16 +22,15 @@ class StateSDVRP(NamedTuple):
     VEHICLE_CAPACITY = 1.0  # Hardcoded
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                demands_with_depot=self.demands_with_depot[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateSDVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            demands_with_depot=self.demands_with_depot[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     @staticmethod
     def initialize(input):
35 changes: 16 additions & 19 deletions utils/beam_search.py
@@ -70,15 +70,14 @@ def ids(self):
         return self.state.ids.view(-1)  # Need to flat as state has steps dimension
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                # ids=self.ids[key],
-                score=self.score[key] if self.score is not None else None,
-                state=self.state[key],
-                parent=self.parent[key] if self.parent is not None else None,
-                action=self.action[key] if self.action is not None else None
-            )
-        return super(BatchBeam, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            # ids=self.ids[key],
+            score=self.score[key] if self.score is not None else None,
+            state=self.state[key],
+            parent=self.parent[key] if self.parent is not None else None,
+            action=self.action[key] if self.action is not None else None
+        )
 
     # Do not use __len__ since this is used by namedtuple internally and should be number of fields
     # def __len__(self):
@@ -207,15 +206,13 @@ def __getitem__(self, key):
         assert not isinstance(key, slice), "CachedLookup does not support slicing, " \
                                            "you can slice the result of an index operation instead"
 
-        if torch.is_tensor(key):  # If tensor, idx all tensors by this tensor:
-
-            if self.key is None:
-                self.key = key
-                self.current = self.orig[key]
-            elif len(key) != len(self.key) or (key != self.key).any():
-                self.key = key
-                self.current = self.orig[key]
+        assert torch.is_tensor(key)  # If tensor, idx all tensors by this tensor:
 
-            return self.current
+        if self.key is None:
+            self.key = key
+            self.current = self.orig[key]
+        elif len(key) != len(self.key) or (key != self.key).any():
+            self.key = key
+            self.current = self.orig[key]
 
-        return super(CachedLookup, self).__getitem__(key)
+        return self.current
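CachedLookup wraps tensors that stay fixed during search and are repeatedly gathered with the same index tensor; it only recomputes orig[key] when the key actually changes, and after this commit it rejects non-tensor keys outright. A condensed standalone sketch of that caching pattern (not a verbatim copy of the class):

import torch

class SimpleCachedLookup:
    def __init__(self, orig):
        self.orig = orig        # any object that supports tensor indexing
        self.key = None
        self.current = None

    def __getitem__(self, key):
        assert torch.is_tensor(key)  # as in the commit, non-tensor keys are rejected
        # Recompute only when the index tensor differs from the cached one.
        if self.key is None or len(key) != len(self.key) or (key != self.key).any():
            self.key = key
            self.current = self.orig[key]
        return self.current

table = SimpleCachedLookup(torch.arange(10) * 10)
idx = torch.tensor([3, 1, 4])
print(table[idx])   # tensor([30, 10, 40]); repeating the call with idx reuses the cache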
