From 6c5b78ec74ae007b9bedb3b36c30c84db4b12002 Mon Sep 17 00:00:00 2001
From: Wouter Kool
Date: Fri, 20 Nov 2020 13:58:52 +0100
Subject: [PATCH 1/4] Remove super call in __getitem__ which causes problems in Python 3.8

---
 nets/attention_model.py       | 17 ++++++++---------
 problems/op/state_op.py       | 19 +++++++++----------
 problems/pctsp/state_pctsp.py | 21 ++++++++++-----------
 problems/tsp/state_tsp.py     | 19 +++++++++----------
 problems/vrp/state_cvrp.py    | 19 +++++++++----------
 problems/vrp/state_sdvrp.py   | 19 +++++++++----------
 utils/beam_search.py          | 35 ++++++++++++++++-------------------
 7 files changed, 70 insertions(+), 79 deletions(-)

diff --git a/nets/attention_model.py b/nets/attention_model.py
index c64395c7..4fda7e01 100644
--- a/nets/attention_model.py
+++ b/nets/attention_model.py
@@ -29,15 +29,14 @@ class AttentionModelFixed(NamedTuple):
     logit_key: torch.Tensor

     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):
-            return AttentionModelFixed(
-                node_embeddings=self.node_embeddings[key],
-                context_node_projected=self.context_node_projected[key],
-                glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
-                glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
-                logit_key=self.logit_key[key]
-            )
-        return super(AttentionModelFixed, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)
+        return AttentionModelFixed(
+            node_embeddings=self.node_embeddings[key],
+            context_node_projected=self.context_node_projected[key],
+            glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
+            glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
+            logit_key=self.logit_key[key]
+        )


 class AttentionModel(nn.Module):
diff --git a/problems/op/state_op.py b/problems/op/state_op.py
index 5d41416c..c1c6db8a 100644
--- a/problems/op/state_op.py
+++ b/problems/op/state_op.py
@@ -36,16 +36,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)

     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-                cur_total_prize=self.cur_total_prize[key],
-            )
-        return super(StateOP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+            cur_total_prize=self.cur_total_prize[key],
+        )

     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
diff --git a/problems/pctsp/state_pctsp.py b/problems/pctsp/state_pctsp.py
index 02f2f590..9555dad2 100644
--- a/problems/pctsp/state_pctsp.py
+++ b/problems/pctsp/state_pctsp.py
@@ -36,17 +36,16 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)

     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_total_prize=self.cur_total_prize[key],
-                cur_total_penalty=self.cur_total_penalty[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StatePCTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_total_prize=self.cur_total_prize[key],
+            cur_total_penalty=self.cur_total_penalty[key],
+            cur_coord=self.cur_coord[key],
+        )

     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
diff --git a/problems/tsp/state_tsp.py b/problems/tsp/state_tsp.py
index 14726078..47306054 100644
--- a/problems/tsp/state_tsp.py
+++ b/problems/tsp/state_tsp.py
@@ -28,16 +28,15 @@ def visited(self):
         return mask_long2bool(self.visited_, n=self.loc.size(-2))

     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                first_a=self.first_a[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
-            )
-        return super(StateTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            first_a=self.first_a[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
+        )

     @staticmethod
     def initialize(loc, visited_dtype=torch.uint8):
diff --git a/problems/vrp/state_cvrp.py b/problems/vrp/state_cvrp.py
index 83764b51..81ba1e67 100644
--- a/problems/vrp/state_cvrp.py
+++ b/problems/vrp/state_cvrp.py
@@ -34,16 +34,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)

     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateCVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )

     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
diff --git a/problems/vrp/state_sdvrp.py b/problems/vrp/state_sdvrp.py
index b4247c1c..1970602b 100644
--- a/problems/vrp/state_sdvrp.py
+++ b/problems/vrp/state_sdvrp.py
@@ -22,16 +22,15 @@ class StateSDVRP(NamedTuple):
     VEHICLE_CAPACITY = 1.0  # Hardcoded

     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                demands_with_depot=self.demands_with_depot[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateSDVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            demands_with_depot=self.demands_with_depot[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )

     @staticmethod
     def initialize(input):
diff --git a/utils/beam_search.py b/utils/beam_search.py
index 6da22729..5e68d206 100644
--- a/utils/beam_search.py
+++ b/utils/beam_search.py
@@ -70,15 +70,14 @@ def ids(self):
         return self.state.ids.view(-1)  # Need to flat as state has steps dimension

     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                # ids=self.ids[key],
-                score=self.score[key] if self.score is not None else None,
-                state=self.state[key],
-                parent=self.parent[key] if self.parent is not None else None,
-                action=self.action[key] if self.action is not None else None
-            )
-        return super(BatchBeam, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            # ids=self.ids[key],
+            score=self.score[key] if self.score is not None else None,
+            state=self.state[key],
+            parent=self.parent[key] if self.parent is not None else None,
+            action=self.action[key] if self.action is not None else None
+        )

     # Do not use __len__ since this is used by namedtuple internally and should be number of fields
     # def __len__(self):
@@ -207,15 +206,13 @@ def __getitem__(self, key):
         assert not isinstance(key, slice), "CachedLookup does not support slicing, " \
                                            "you can slice the result of an index operation instead"

-        if torch.is_tensor(key):  # If tensor, idx all tensors by this tensor:
-
-            if self.key is None:
-                self.key = key
-                self.current = self.orig[key]
-            elif len(key) != len(self.key) or (key != self.key).any():
-                self.key = key
-                self.current = self.orig[key]
+        assert torch.is_tensor(key)  # If tensor, idx all tensors by this tensor:

-            return self.current
+        if self.key is None:
+            self.key = key
+            self.current = self.orig[key]
+        elif len(key) != len(self.key) or (key != self.key).any():
+            self.key = key
+            self.current = self.orig[key]

-        return super(CachedLookup, self).__getitem__(key)
+        return self.current
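Note (illustrative sketch, not part of the patch series): after this change the state NamedTuples accept only tensor or slice keys in __getitem__, which is how the model and beam search select a sub-batch; individual fields are read by attribute access, not by integer indexing. The toy class below mirrors the patched pattern; ToyState and its fields are invented for the example, which assumes Python 3.8+ and an installed PyTorch.

    import torch
    from typing import NamedTuple

    class ToyState(NamedTuple):
        ids: torch.Tensor
        lengths: torch.Tensor

        def __getitem__(self, key):
            # Same pattern as the patched states: only tensor/slice keys are valid,
            # and every field is indexed so the result is again a consistent state.
            assert torch.is_tensor(key) or isinstance(key, slice)
            return self._replace(ids=self.ids[key], lengths=self.lengths[key])

    s = ToyState(ids=torch.arange(4), lengths=torch.zeros(4))
    kept = s[torch.tensor([0, 2])]   # keep a sub-batch; all fields stay aligned
    print(kept.ids, kept.lengths)    # named-field access is unaffected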
From c1d77bdc181771da30563db484dd72104ab32b80 Mon Sep 17 00:00:00 2001
From: Wouter Kool
Date: Fri, 20 Nov 2020 13:59:50 +0100
Subject: [PATCH 2/4] Fix: long division should perform floor division (changed behaviour in PyTorch 1.7)

---
 nets/attention_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nets/attention_model.py b/nets/attention_model.py
index 4fda7e01..d8d2fab1 100644
--- a/nets/attention_model.py
+++ b/nets/attention_model.py
@@ -171,7 +171,7 @@ def propose_expansions(self, beam, fixed, expand_size=None, normalize=False, max
         flat_feas = flat_score > -1e10  # != -math.inf triggers

         # Parent is row idx of ind_topk, can be found by enumerating elements and dividing by number of columns
-        flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) / ind_topk.size(-1)
+        flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) // ind_topk.size(-1)

         # Filter infeasible
         feas_ind_2d = torch.nonzero(flat_feas)
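Note (illustrative sketch, not part of the patch series): on recent PyTorch releases the / operator on integer tensors performs true division and returns floats, so the parent-row computation above needs explicit floor division with //. A minimal example of the difference:

    import torch

    # 6 flat positions laid out as rows of 3 columns; the parent (row) index of a
    # position is its flat index floor-divided by the number of columns.
    flat = torch.arange(6)
    print(flat / 3)   # true division: float values, unusable as an index
    print(flat // 3)  # floor division: tensor([0, 0, 0, 1, 1, 1])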
From b8c9b8a02d8955cb39be06a4130b3233880927b2 Mon Sep 17 00:00:00 2001
From: Wouter Kool
Date: Fri, 20 Nov 2020 14:41:06 +0100
Subject: [PATCH 3/4] Alternative implementations for attention in comments for clarity

---
 nets/graph_encoder.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/nets/graph_encoder.py b/nets/graph_encoder.py
index a5cf68f1..90a302c8 100644
--- a/nets/graph_encoder.py
+++ b/nets/graph_encoder.py
@@ -109,6 +109,15 @@ def forward(self, q, h=None, mask=None):
             self.W_out.view(-1, self.embed_dim)
         ).view(batch_size, n_query, self.embed_dim)

+        # Alternative:
+        # headst = heads.transpose(0, 1)  # swap the dimensions for batch and heads to align it for the matmul
+        # # proj_h = torch.einsum('bhni,hij->bhnj', headst, self.W_out)
+        # projected_heads = torch.matmul(headst, self.W_out)
+        # out = torch.sum(projected_heads, dim=1)  # sum across heads
+
+        # Or:
+        # out = torch.einsum('hbni,hij->bnj', heads, self.W_out)
+
         return out
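Note (illustrative sketch, not part of the patch series): the commented alternatives compute the same output projection as the existing torch.mm formulation, because multiplying the concatenated heads by the vertically stacked W_out equals projecting each head separately and summing over heads. A small numerical check, with all sizes chosen arbitrarily for the sketch:

    import torch

    n_heads, batch, n, val_dim, embed_dim = 8, 2, 5, 16, 128
    heads = torch.randn(n_heads, batch, n, val_dim)   # per-head values after attention
    W_out = torch.randn(n_heads, val_dim, embed_dim)  # per-head output projection

    # Existing formulation: concatenate the heads and apply one big matmul
    out_mm = torch.mm(
        heads.permute(1, 2, 0, 3).contiguous().view(-1, n_heads * val_dim),
        W_out.view(-1, embed_dim)
    ).view(batch, n, embed_dim)

    # Alternative from the comments: project each head, then sum across heads
    out_sum = torch.matmul(heads.transpose(0, 1), W_out).sum(dim=1)

    # Single-einsum version
    out_einsum = torch.einsum('hbni,hij->bnj', heads, W_out)

    assert torch.allclose(out_mm, out_sum, atol=1e-4)
    assert torch.allclose(out_mm, out_einsum, atol=1e-4)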
From 5fa0b17e72b28fd9464fc46753fba04a5c17cf74 Mon Sep 17 00:00:00 2001
From: Wouter Kool
Date: Fri, 20 Nov 2020 14:48:09 +0100
Subject: [PATCH 4/4] Make embed_dim required and fix size mismatch bug when key_dim and val_dim are different

---
 nets/graph_encoder.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/nets/graph_encoder.py b/nets/graph_encoder.py
index 90a302c8..e3a4cf88 100644
--- a/nets/graph_encoder.py
+++ b/nets/graph_encoder.py
@@ -19,14 +19,13 @@ def __init__(
             self,
             n_heads,
             input_dim,
-            embed_dim=None,
+            embed_dim,
             val_dim=None,
             key_dim=None
     ):
         super(MultiHeadAttention, self).__init__()

         if val_dim is None:
-            assert embed_dim is not None, "Provide either embed_dim or val_dim"
             val_dim = embed_dim // n_heads
         if key_dim is None:
             key_dim = val_dim
@@ -43,8 +42,7 @@ def __init__(
         self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
         self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))

-        if embed_dim is not None:
-            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim))
+        self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))

        self.init_parameters()
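Note (illustrative sketch, not part of the patch series): the output projection is applied to the per-head values, whose width is val_dim, so W_out must have shape (n_heads, val_dim, embed_dim); the old (n_heads, key_dim, embed_dim) shape only happened to work when key_dim and val_dim were equal. A rough shape check using the einsum form from the previous patch, with made-up sizes:

    import torch

    n_heads, batch, n, embed_dim = 8, 2, 5, 128
    key_dim, val_dim = 16, 32                             # deliberately different

    heads = torch.randn(n_heads, batch, n, val_dim)       # attention output is val_dim wide per head
    W_out_old = torch.randn(n_heads, key_dim, embed_dim)  # old parameter shape
    W_out_new = torch.randn(n_heads, val_dim, embed_dim)  # fixed parameter shape

    out = torch.einsum('hbni,hij->bnj', heads, W_out_new)  # (batch, n, embed_dim): works
    try:
        torch.einsum('hbni,hij->bnj', heads, W_out_old)
    except RuntimeError as err:
        print('size mismatch with the old shape:', err)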