Skip to content

Commit

Permalink
set 0 as mask token
Browse files Browse the repository at this point in the history
  • Loading branch information
tonydavis629 committed Apr 13, 2023
1 parent 6197ca4 commit 6939e3c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
9 changes: 6 additions & 3 deletions deepchem/feat/molecule_featurizers/snap_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

allowable_features = {
'possible_atomic_num_list':
list(range(1, 119)),
list(range(0, 119)), # 0 represents a masked atom
'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
'possible_chirality_list': [
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Expand All @@ -23,8 +23,11 @@
'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'possible_bonds': [
Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC
0, # 0 represents a masked bond
Chem.rdchem.BondType.SINGLE,
Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC
],
'possible_bond_dirs': [ # only for double bond stereo information
Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
Expand Down
2 changes: 1 addition & 1 deletion deepchem/models/tests/test_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_GNN_node_masking():
from deepchem.models.torch_models.gnn import GNNModular

dataset, _ = get_regression_dataset()
model = GNNModular(task="mask_nodes")
model = GNNModular(task="mask_nodes", device="cpu")
loss1 = model.fit(dataset, nb_epoch=5)
loss2 = model.fit(dataset, nb_epoch=5)
assert loss2 < loss1
Expand Down
24 changes: 15 additions & 9 deletions deepchem/models/torch_models/gnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
import copy
import torch
import numpy as np
from torch_geometric.nn import GINEConv, global_add_pool, global_mean_pool, global_max_pool
Expand Down Expand Up @@ -337,7 +338,7 @@ def build_components(self):
1) # -1 to remove mask token
linear_pred_edges = torch.nn.Linear(
self.emb_dim,
num_edge_type - 2) # -2 to remove mask token and self-loop
num_edge_type - 1) # -1 to remove mask token
components.update({
'linear_pred_nodes': linear_pred_nodes,
'linear_pred_edges': linear_pred_edges
Expand Down Expand Up @@ -511,14 +512,14 @@ def default_generator(
yield ([X_b], [y_b], [w_b])


def negative_edge_sampler(data: BatchGraphData):
def negative_edge_sampler(input_graph: BatchGraphData):
"""
NegativeEdge is a function that adds negative edges to the input graph data. It randomly samples negative edges (edges that do not exist in the original graph) and adds them to the input graph data.
The number of negative edges added is equal to half the number of edges in the original graph. This is useful for tasks like edge prediction, where the model needs to learn to differentiate between existing and non-existing edges.
Parameters
----------
data: dc.feat.graph_data.BatchGraphData
input_graph: dc.feat.graph_data.BatchGraphData
The input graph data.
Returns
Expand Down Expand Up @@ -551,7 +552,7 @@ def negative_edge_sampler(data: BatchGraphData):
>>> batched_graph = batched_graph.numpy_to_torch()
>>> neg_sampled = negative_edge_sampler(batched_graph)
"""
import torch
data = copy.deepcopy(input_graph)

num_nodes = data.num_nodes
num_edges = data.num_edges
Expand Down Expand Up @@ -580,16 +581,18 @@ def negative_edge_sampler(data: BatchGraphData):
return data


def mask_nodes(data: BatchGraphData,
def mask_nodes(input_graph: BatchGraphData,
mask_rate,
masked_node_indices=None,
mask_edge=True):
"""
Mask nodes and their connected edges in a PyTorch geometric data object.
Mask nodes and their connected edges in a BatchGraphData object.
This function assumes that the first node feature is the atomic number, for example with the SNAPFeaturizer. It will set masked nodes' features to 0.
Parameters
----------
data: dc.feat.BatchGraphData
input_graph: dc.feat.BatchGraphData
Assume that the edge ordering is the default PyTorch geometric ordering, where the two directions of a single edge occur in pairs.
Eg. data.edge_index = tensor([[0, 1, 1, 2, 2, 3],
[1, 0, 2, 1, 3, 2]])
Expand All @@ -608,6 +611,7 @@ def mask_nodes(data: BatchGraphData,
- data.mask_edge_label
"""
data = copy.deepcopy(input_graph)

if masked_node_indices is None:
# sample x distinct nodes to be masked, based on mask rate. But
Expand Down Expand Up @@ -679,7 +683,7 @@ def mask_nodes(data: BatchGraphData,
return data


def mask_edges(data: BatchGraphData,
def mask_edges(input_graph: BatchGraphData,
mask_rate: float,
masked_edge_indices=None):
"""
Expand All @@ -689,7 +693,7 @@ def mask_edges(data: BatchGraphData,
Parameters
----------
data: dc.feat.BatchGraphData
input_graph: dc.feat.BatchGraphData
Assume that the edge ordering is the default PyTorch geometric ordering, where the two directions of a single edge occur in pairs.
Eg. data.edge_index = tensor([[0, 1, 1, 2, 2, 3],
[1, 0, 2, 1, 3, 2]])
Expand All @@ -704,6 +708,8 @@ def mask_edges(data: BatchGraphData,
- data.mask_edge_labels: corresponding ground truth edge feature for each masked edge
- data.edge_attr: modified in place: the edge features (both directions) that correspond to the masked edges have the masked edge feature
"""
data = copy.deepcopy(input_graph)

if masked_edge_indices is None:
# sample x distinct edges to be masked, based on mask rate. But
# will sample at least 1 edge
Expand Down

0 comments on commit 6939e3c

Please sign in to comment.