Commit

docs details
tonydavis629 committed Apr 11, 2023
1 parent a4f025f commit d2eabc3
Showing 2 changed files with 8 additions and 2 deletions.
6 changes: 4 additions & 2 deletions deepchem/models/losses.py
@@ -911,9 +911,11 @@ def loss(atom_vocab_task_atom_pred: torch.Tensor,

class EdgePredictionLoss(Loss):
"""
Unsupervised graph edge prediction loss.
EdgePredictionLoss is an unsupervised graph edge prediction loss function that calculates the loss based on the similarity between node embeddings for positive and negative edge pairs. This loss function is designed for graph neural networks and is particularly useful for pre-training tasks.
The input to this loss must be a BatchGraphData object transformed by the negative_edge_sampler.
This loss function encourages the model to learn node embeddings that can effectively distinguish between true edges (positive samples) and false edges (negative samples) in the graph. The loss is computed by comparing the similarity scores (dot product) of node embeddings for positive and negative edge pairs. The goal is to maximize the similarity for positive pairs and minimize it for negative pairs.
To use this loss function, the input must be a BatchGraphData object transformed by the negative_edge_sampler. The loss function takes the node embeddings and the input graph data (with positive and negative edge pairs) as inputs and returns the edge prediction loss.
Examples
--------
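The idea described in the updated docstring can be illustrated with a short sketch: score each candidate edge by the dot product of its endpoint node embeddings, then push scores for positive (true) edges up and scores for sampled negative edges down. The helper name and tensor layout below are assumptions for illustration only, not the actual DeepChem implementation.

import torch
import torch.nn.functional as F

def edge_prediction_loss_sketch(node_emb, pos_edge_index, neg_edge_index):
    # node_emb: (num_nodes, emb_dim); *_edge_index: (2, num_edges) source/target node pairs.
    pos_score = (node_emb[pos_edge_index[0]] * node_emb[pos_edge_index[1]]).sum(dim=1)
    neg_score = (node_emb[neg_edge_index[0]] * node_emb[neg_edge_index[1]]).sum(dim=1)
    # Maximize similarity for positive pairs, minimize it for negative pairs.
    pos_loss = F.binary_cross_entropy_with_logits(pos_score, torch.ones_like(pos_score))
    neg_loss = F.binary_cross_entropy_with_logits(neg_score, torch.zeros_like(neg_score))
    return pos_loss + neg_loss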
4 changes: 4 additions & 0 deletions deepchem/models/torch_models/gnn.py
@@ -428,6 +428,10 @@ def default_generator(
mode: str = 'fit',
deterministic: bool = True,
pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
"""
This default generator is modified from the default generator in dc.models.tensorgraph.tensor_graph.py to support multitask classification. If the task is classification, the labels y_b are converted to a one-hot encoding and reshaped according to the number of tasks and classes.
"""

for epoch in range(epochs):
for (X_b, y_b, w_b,
ids_b) in dataset.iterbatches(batch_size=self.batch_size,
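The one-hot conversion and reshaping described in the new docstring can be sketched as follows; the names n_tasks and n_classes are assumptions for illustration and may not match the model's actual attributes.

import numpy as np

def to_one_hot_multitask(y_b, n_tasks, n_classes):
    # y_b: integer class labels with one column per task, reshaped to (batch, tasks).
    y_b = y_b.astype(int).reshape(-1, n_tasks)
    # Index an identity matrix to get one-hot vectors, shape (batch, tasks, classes).
    return np.eye(n_classes)[y_b]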
