Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tonydavis629 committed Apr 12, 2023
1 parent 8c0b42e commit 61be4a7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion deepchem/models/tests/test_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_GNN_node_masking():
loss1 = model.fit(dataset, nb_epoch=5)
loss2 = model.fit(dataset, nb_epoch=5)
assert loss2 < loss1
test_GNN_node_masking()


@pytest.mark.torch
def test_GNN_edge_masking():
Expand Down
25 changes: 13 additions & 12 deletions deepchem/models/torch_models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,8 @@ def mask_nodes(data: BatchGraphData,
mask_node_labels_list = []
for node_idx in masked_node_indices:
mask_node_labels_list.append(data.node_features[node_idx].view(1, -1))
data.mask_node_label = torch.cat(mask_node_labels_list, dim=0)[:, 0].long()
data.masked_node_indices = torch.tensor(masked_node_indices)
data.mask_node_label = torch.cat(mask_node_labels_list, dim=0)[:, 0].long() # type: ignore
data.masked_node_indices = torch.tensor(masked_node_indices) # type: ignore

# modify the original node feature of the masked node
num_node_feats = data.node_features.size()[1]
Expand Down Expand Up @@ -653,8 +653,9 @@ def mask_nodes(data: BatchGraphData,
mask_edge_labels_list.append(data.edge_features[edge_idx].view(
1, -1))

data.mask_edge_label = torch.cat(mask_edge_labels_list,
dim=0)[:, 0].long()
data.mask_edge_label = torch.cat( # type: ignore
mask_edge_labels_list,
dim=0)[:, 0].long()
# modify the original edge features of the edges connected to the mask nodes
num_edge_feat = data.edge_features.size()[1]
for edge_idx in connected_edge_indices:
Expand All @@ -663,12 +664,12 @@ def mask_nodes(data: BatchGraphData,
# original implementation, where the masked features are represented by the
# the last feature token 4.
# link to source: https://github.com/snap-stanford/pretrain-gnns/blob/08f126ac13623e551a396dd5e511d766f9d4f8ff/chem/util.py#L268
data.connected_edge_indices = torch.tensor(

data.connected_edge_indices = torch.tensor( # type: ignore
connected_edge_indices[::2])
else:
data.mask_edge_label = torch.empty((0, 2)).to(torch.int64)
data.connected_edge_indices = torch.tensor(
data.mask_edge_label = torch.empty((0, 2)).to(torch.int64) # type: ignore
data.connected_edge_indices = torch.tensor( # type: ignore
connected_edge_indices).to(torch.int64)

return data
Expand Down Expand Up @@ -710,16 +711,16 @@ def mask_edges(data: BatchGraphData,
2 * i for i in random.sample(range(num_edges), sample_size)
]

data.masked_edge_idx = torch.tensor(
np.array(masked_edge_indices)) # type: ignore
data.masked_edge_idx = torch.tensor( # type: ignore
np.array(masked_edge_indices))

# create ground truth edge features for the edges that correspond to
# the masked indices
mask_edge_labels_list = []
for idx in masked_edge_indices:
mask_edge_labels_list.append(data.edge_features[idx].view(1, -1))
data.mask_edge_label = torch.cat(mask_edge_labels_list,
dim=0) # type: ignore
data.mask_edge_label = torch.cat(mask_edge_labels_list, # type: ignore
dim=0)

# created new masked edge_attr, where both directions of the masked
# edges have masked edge type. For message passing in gcn
Expand Down

0 comments on commit 61be4a7

Please sign in to comment.