diff --git a/deepchem/feat/molecule_featurizers/snap_featurizer.py b/deepchem/feat/molecule_featurizers/snap_featurizer.py index 00110aafba..451b3fff19 100644 --- a/deepchem/feat/molecule_featurizers/snap_featurizer.py +++ b/deepchem/feat/molecule_featurizers/snap_featurizer.py @@ -5,7 +5,7 @@ allowable_features = { 'possible_atomic_num_list': - list(range(1, 119)), + list(range(0, 119)), # 0 represents a masked atom 'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], 'possible_chirality_list': [ Chem.rdchem.ChiralType.CHI_UNSPECIFIED, @@ -23,8 +23,11 @@ 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6], 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'possible_bonds': [ - Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, - Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC + 0, # 0 represents a masked bond + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC ], 'possible_bond_dirs': [ # only for double bond stereo information Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT, diff --git a/deepchem/models/tests/test_gnn.py b/deepchem/models/tests/test_gnn.py index 2ea4828c10..8b405c12d4 100644 --- a/deepchem/models/tests/test_gnn.py +++ b/deepchem/models/tests/test_gnn.py @@ -92,7 +92,7 @@ def test_GNN_node_masking(): from deepchem.models.torch_models.gnn import GNNModular dataset, _ = get_regression_dataset() - model = GNNModular(task="mask_nodes") + model = GNNModular(task="mask_nodes", device="cpu") loss1 = model.fit(dataset, nb_epoch=5) loss2 = model.fit(dataset, nb_epoch=5) assert loss2 < loss1 diff --git a/deepchem/models/torch_models/gnn.py b/deepchem/models/torch_models/gnn.py index 3ab34fc1f5..e409591b13 100644 --- a/deepchem/models/torch_models/gnn.py +++ b/deepchem/models/torch_models/gnn.py @@ -1,4 +1,5 @@ import random +import copy import torch import numpy as np from torch_geometric.nn import GINEConv, global_add_pool, global_mean_pool, global_max_pool @@ -337,7 +338,7 @@ def build_components(self): 1) # -1 to remove mask token linear_pred_edges = torch.nn.Linear( self.emb_dim, - num_edge_type - 2) # -2 to remove mask token and self-loop + num_edge_type - 1) # -1 to remove mask token components.update({ 'linear_pred_nodes': linear_pred_nodes, 'linear_pred_edges': linear_pred_edges @@ -511,14 +512,14 @@ def default_generator( yield ([X_b], [y_b], [w_b]) -def negative_edge_sampler(data: BatchGraphData): +def negative_edge_sampler(input_graph: BatchGraphData): """ NegativeEdge is a function that adds negative edges to the input graph data. It randomly samples negative edges (edges that do not exist in the original graph) and adds them to the input graph data. The number of negative edges added is equal to half the number of edges in the original graph. This is useful for tasks like edge prediction, where the model needs to learn to differentiate between existing and non-existing edges. Parameters ---------- - data: dc.feat.graph_data.BatchGraphData + input_graph: dc.feat.graph_data.BatchGraphData The input graph data. Returns @@ -551,7 +552,7 @@ def negative_edge_sampler(data: BatchGraphData): >>> batched_graph = batched_graph.numpy_to_torch() >>> neg_sampled = negative_edge_sampler(batched_graph) """ - import torch + data = copy.deepcopy(input_graph) num_nodes = data.num_nodes num_edges = data.num_edges @@ -580,16 +581,18 @@ def negative_edge_sampler(data: BatchGraphData): return data -def mask_nodes(data: BatchGraphData, +def mask_nodes(input_graph: BatchGraphData, mask_rate, masked_node_indices=None, mask_edge=True): """ - Mask nodes and their connected edges in a PyTorch geometric data object. + Mask nodes and their connected edges in a BatchGraphData object. + + This function assumes that the first node feature is the atomic number, for example with the SNAPFeaturizer. It will set masked nodes' features to 0. Parameters ---------- - data: dc.feat.BatchGraphData + input_graph: dc.feat.BatchGraphData Assume that the edge ordering is the default PyTorch geometric ordering, where the two directions of a single edge occur in pairs. Eg. data.edge_index = tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) @@ -608,6 +611,7 @@ def mask_nodes(data: BatchGraphData, - data.mask_edge_label """ + data = copy.deepcopy(input_graph) if masked_node_indices is None: # sample x distinct nodes to be masked, based on mask rate. But @@ -679,7 +683,7 @@ def mask_nodes(data: BatchGraphData, return data -def mask_edges(data: BatchGraphData, +def mask_edges(input_graph: BatchGraphData, mask_rate: float, masked_edge_indices=None): """ @@ -689,7 +693,7 @@ def mask_edges(data: BatchGraphData, Parameters ---------- - data: dc.feat.BatchGraphData + input_graph: dc.feat.BatchGraphData Assume that the edge ordering is the default PyTorch geometric ordering, where the two directions of a single edge occur in pairs. Eg. data.edge_index = tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) @@ -704,6 +708,8 @@ def mask_edges(data: BatchGraphData, - data.mask_edge_labels: corresponding ground truth edge feature for each masked edge - data.edge_attr: modified in place: the edge features (both directions) that correspond to the masked edges have the masked edge feature """ + data = copy.deepcopy(input_graph) + if masked_edge_indices is None: # sample x distinct edges to be masked, based on mask rate. But # will sample at least 1 edge