Update (dmlc#1366)
mufeili authored Mar 16, 2020
1 parent 1662757 commit 2ce0e21
Showing 55 changed files with 309 additions and 63 deletions.
12 changes: 9 additions & 3 deletions apps/life_sci/README.md
@@ -21,9 +21,15 @@ Depending on the features you want to use, you may need to manually install the
 - RDKit 2018.09.3
     - We recommend installation with `conda install -c conda-forge rdkit==2018.09.3`. For other installation recipes,
     see the [official documentation](https://www.rdkit.org/docs/Install.html).
-- (optional) MDTraj
-    - We recommend installation with `conda install -c conda-forge mdtraj`. For alternative ways of installation,
-    see the [official documentation](http://mdtraj.org/1.9.3/installation.html).

+## Installation
+
+To install the package,
+
+```bash
+cd python
+python setup.py install
+```
+
 ## Organization

26 changes: 13 additions & 13 deletions apps/life_sci/examples/README.md
@@ -5,49 +5,49 @@ We provide various examples across 3 applications -- property prediction, genera
 ## Datasets/Benchmarks

 - MoleculeNet: A Benchmark for Molecular Machine Learning [[paper]](https://arxiv.org/abs/1703.00564), [[website]](http://moleculenet.ai/)
-    - [Tox21 with DGL](../dgllife/data/tox21.py)
-    - [PDBBind with DGL](../dgllife/data/pdbbind.py)
+    - [Tox21 with DGL](../python/dgllife/data/tox21.py)
+    - [PDBBind with DGL](../python/dgllife/data/pdbbind.py)
 - Alchemy: A Quantum Chemistry Dataset for Benchmarking AI Models [[paper]](https://arxiv.org/abs/1906.09427), [[github]](https://github.com/tencent-alchemy/Alchemy)
-    - [Alchemy with DGL](../dgllife/data/alchemy.py)
+    - [Alchemy with DGL](../python/dgllife/data/alchemy.py)

 ## Property Prediction

 - Semi-Supervised Classification with Graph Convolutional Networks (GCN) [[paper]](https://arxiv.org/abs/1609.02907), [[github]](https://github.com/tkipf/gcn)
-    - [GCN-Based Predictor with DGL](../dgllife/model/model_zoo/gcn_predictor.py)
+    - [GCN-Based Predictor with DGL](../python/dgllife/model/model_zoo/gcn_predictor.py)
     - [Example for Molecule Classification](property_prediction/classification.py)
 - Graph Attention Networks (GAT) [[paper]](https://arxiv.org/abs/1710.10903), [[github]](https://github.com/PetarV-/GAT)
-    - [GAT-Based Predictor with DGL](../dgllife/model/model_zoo/gat_predictor.py)
+    - [GAT-Based Predictor with DGL](../python/dgllife/model/model_zoo/gat_predictor.py)
     - [Example for Molecule Classification](property_prediction/classification.py)
 - SchNet: A continuous-filter convolutional neural network for modeling quantum interactions [[paper]](https://arxiv.org/abs/1706.08566), [[github]](https://github.com/atomistic-machine-learning/SchNet)
-    - [SchNet with DGL](../dgllife/model/model_zoo/schnet_predictor.py)
+    - [SchNet with DGL](../python/dgllife/model/model_zoo/schnet_predictor.py)
     - [Example for Molecule Regression](property_prediction/regression.py)
 - Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective (MGCN) [[paper]](https://arxiv.org/abs/1906.11081)
-    - [MGCN with DGL](../dgllife/model/model_zoo/mgcn_predictor.py)
+    - [MGCN with DGL](../python/dgllife/model/model_zoo/mgcn_predictor.py)
     - [Example for Molecule Regression](property_prediction/regression.py)
 - Neural Message Passing for Quantum Chemistry (MPNN) [[paper]](https://arxiv.org/abs/1704.01212), [[github]](https://github.com/brain-research/mpnn)
-    - [MPNN with DGL](../dgllife/model/model_zoo/mpnn_predictor.py)
+    - [MPNN with DGL](../python/dgllife/model/model_zoo/mpnn_predictor.py)
     - [Example for Molecule Regression](property_prediction/regression.py)
 - Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism (AttentiveFP) [[paper]](https://pubs.acs.org/doi/abs/10.1021/acs.jmedchem.9b00959)
-    - [AttentiveFP with DGL](../dgllife/model/model_zoo/attentivefp_predictor.py)
+    - [AttentiveFP with DGL](../python/dgllife/model/model_zoo/attentivefp_predictor.py)
     - [Example for Molecule Regression](property_prediction/regression.py)

 ## Generative Models

 - Learning Deep Generative Models of Graphs (DGMG) [[paper]](https://arxiv.org/abs/1803.03324)
-    - [DGMG with DGL](../dgllife/model/model_zoo/dgmg.py)
+    - [DGMG with DGL](../python/dgllife/model/model_zoo/dgmg.py)
     - [Example Training Script](generative_models/dgmg)
 - Junction Tree Variational Autoencoder for Molecular Graph Generation (JTNN) [[paper]](https://arxiv.org/abs/1802.04364)
-    - [JTNN with DGL](../dgllife/model/model_zoo/jtnn)
+    - [JTNN with DGL](../python/dgllife/model/model_zoo/jtnn)
     - [Example Training Script](generative_models/jtnn)

 ## Binding Affinity Prediction

 - Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity (ACNN) [[paper]](https://arxiv.org/abs/1703.10603), [[github]](https://github.com/deepchem/deepchem/tree/master/contrib/atomicconv)
-    - [ACNN with DGL](../dgllife/model/model_zoo/acnn.py)
+    - [ACNN with DGL](../python/dgllife/model/model_zoo/acnn.py)
     - [Example Training Script](binding_affinity_prediction)

 ## Reaction Prediction
 - A graph-convolutional neural network model for the prediction of chemical reactivity [[paper]](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract), [[github]](https://github.com/connorcoley/rexgen_direct)
     - An earlier version was published in NeurIPS 2017 as "Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network" [[paper]](https://arxiv.org/abs/1709.04555)
-    - [WLN with DGL for Reaction Center Prediction](../dgllife/model/model_zoo/wln_reaction_center.py)
+    - [WLN with DGL for Reaction Center Prediction](../python/dgllife/model/model_zoo/wln_reaction_center.py)
     - [Example Script](reaction_prediction/rexgen_direct)
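
Every changed link in this hunk follows one pattern: `../dgllife/` becomes `../python/dgllife/`. The sketch below shows one way such a rewrite could be scripted. It is an illustration only, and `repoint_links` is a hypothetical helper name, not something from the commit.

```python
import re

# Matches the start of a Markdown link target that still points at the old
# package location. Links already under ../python/ do not match.
OLD_TARGET = re.compile(r'\]\(\.\./dgllife/')

def repoint_links(markdown):
    """Rewrite `](../dgllife/...` link targets to `](../python/dgllife/...`."""
    return OLD_TARGET.sub('](../python/dgllife/', markdown)

before = '- [Tox21 with DGL](../dgllife/data/tox21.py)'
after = repoint_links(before)
```

Because updated links no longer match the pattern, running the helper a second time leaves the text unchanged.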
File renamed without changes.
@@ -1,22 +1,20 @@
 """Convert molecules into DGLGraphs."""
-import numpy as np
+import torch

 from dgl import DGLGraph
 from functools import partial
 from rdkit import Chem
 from rdkit.Chem import rdmolfiles, rdmolops
-
-try:
-    import mdtraj
-except ImportError:
-    pass
+from sklearn.neighbors import NearestNeighbors

 __all__ = ['mol_to_graph',
            'smiles_to_bigraph',
            'mol_to_bigraph',
            'smiles_to_complete_graph',
            'mol_to_complete_graph',
-           'k_nearest_neighbors']
+           'k_nearest_neighbors',
+           'mol_to_nearest_neighbor_graph',
+           'smiles_to_nearest_neighbor_graph']

 def mol_to_graph(mol, graph_constructor, node_featurizer, edge_featurizer, canonical_atom_order):
     """Convert an RDKit molecule object into a DGLGraph and featurize for it.
@@ -262,51 +260,207 @@ def smiles_to_complete_graph(smiles, add_self_loop=False,
     return mol_to_complete_graph(mol, add_self_loop, node_featurizer,
                                  edge_featurizer, canonical_atom_order)

-def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors):
-    """Find k nearest neighbors for each atom based on the 3D coordinates and
-    return the resulted edges.
+def k_nearest_neighbors(coordinates, neighbor_cutoff, max_num_neighbors=None,
+                        p_distance=2, self_loops=False):
+    """Find k nearest neighbors for each atom
     For each atom, find its k nearest neighbors and return edges
     from these neighbors to it.
+    We do not guarantee that the edges are sorted according to the distance
+    between atoms.
     Parameters
     ----------
-    coordinates : numpy.ndarray of shape (N, 3)
-        The 3D coordinates of atoms in the molecule. N for the number of atoms.
+    coordinates : numpy.ndarray of shape (N, D)
+        The coordinates of atoms in the molecule. N for the number of atoms
+        and D for the dimensions of the coordinates.
     neighbor_cutoff : float
-        Distance cutoff to define 'neighboring'.
+        If the distance between a pair of nodes is larger than neighbor_cutoff,
+        they will not be considered as neighboring nodes.
     max_num_neighbors : int or None.
-        If not None, then this specifies the maximum number of closest neighbors
-        allowed for each atom.
+        If not None, then this specifies the maximum number of neighbors
+        allowed for each atom. Default to None.
+    p_distance : int
+        We compute the distance between neighbors using Minkowski (:math:`l_p`)
+        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
+        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
+        equivalent to the standard Euclidean distance. Default to 2.
+    self_loops : bool
+        Whether to allow a node to be its own neighbor. Default to False.
     Returns
     -------
     srcs : list of int
         Source nodes.
     dsts : list of int
-        Destination nodes.
+        Destination nodes, corresponding to ``srcs``.
     distances : list of float
-        Distances between the end nodes.
+        Distances between the end nodes, corresponding to ``srcs`` and ``dsts``.
     """
     num_atoms = coordinates.shape[0]
-    traj = mdtraj.Trajectory(coordinates.reshape((1, num_atoms, 3)), None)
-    neighbors = mdtraj.geometry.compute_neighborlist(traj, neighbor_cutoff)
-    srcs, dsts, distances = [], [], []
+    model = NearestNeighbors(radius=neighbor_cutoff, p=p_distance)
+    model.fit(coordinates)
+    dists_, nbrs = model.radius_neighbors(coordinates)
+    srcs, dsts, dists = [], [], []
     for i in range(num_atoms):
-        delta = coordinates[i] - coordinates.take(neighbors[i], axis=0)
-        dist = np.linalg.norm(delta, axis=1)
-        if max_num_neighbors is not None and len(neighbors[i]) > max_num_neighbors:
-            sorted_neighbors = list(zip(dist, neighbors[i]))
+        dists_i = dists_[i].tolist()
+        nbrs_i = nbrs[i].tolist()
+        if not self_loops:
+            dists_i.remove(0)
+            nbrs_i.remove(i)
+        if max_num_neighbors is not None and len(nbrs_i) > max_num_neighbors:
+            packed_nbrs = list(zip(dists_i, nbrs_i))
             # Sort neighbors based on distance from smallest to largest
-            sorted_neighbors.sort(key=lambda tup: tup[0])
+            packed_nbrs.sort(key=lambda tup: tup[0])
+            dists_i, nbrs_i = map(list, zip(*packed_nbrs))
             dsts.extend([i for _ in range(max_num_neighbors)])
-            srcs.extend([int(sorted_neighbors[j][1]) for j in range(max_num_neighbors)])
-            distances.extend([float(sorted_neighbors[j][0]) for j in range(max_num_neighbors)])
+            srcs.extend(nbrs_i[:max_num_neighbors])
+            dists.extend(dists_i[:max_num_neighbors])
         else:
-            dsts.extend([i for _ in range(len(neighbors[i]))])
-            srcs.extend(neighbors[i].tolist())
-            distances.extend(dist.tolist())
+            dsts.extend([i for _ in range(len(nbrs_i))])
+            srcs.extend(nbrs_i)
+            dists.extend(dists_i)
+
+    return srcs, dsts, dists
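
The search that the new `k_nearest_neighbors` delegates to scikit-learn can be sketched without any dependencies. The block below is a plain-Python re-implementation for illustration, not the library code. The name `radius_neighbors_sketch` is hypothetical, and treating a distance equal to the cutoff as "neighboring" is an assumption. It uses the Minkowski distance d_p(x, y) = (sum_k |x_k - y_k|^p)^(1/p), and it drops the self pair by index rather than by distance value, so distances and neighbor indices can stay aligned even if two atoms share coordinates.

```python
def radius_neighbors_sketch(coords, cutoff, max_num_neighbors=None,
                            p=2, self_loops=False):
    """Return (srcs, dsts, dists) for edges pointing from each neighbor to the atom."""
    srcs, dsts, dists = [], [], []
    for i, ci in enumerate(coords):
        candidates = []
        for j, cj in enumerate(coords):
            if j == i and not self_loops:
                continue  # skip the self pair by index, not by distance value
            # Minkowski (l_p) distance between atoms i and j
            d = sum(abs(a - b) ** p for a, b in zip(ci, cj)) ** (1.0 / p)
            if d <= cutoff:  # assumption: the cutoff is inclusive
                candidates.append((d, j))
        if max_num_neighbors is not None and len(candidates) > max_num_neighbors:
            # Keep only the closest neighbors when the cap is exceeded
            candidates.sort(key=lambda pair: pair[0])
            candidates = candidates[:max_num_neighbors]
        for d, j in candidates:
            srcs.append(j)
            dsts.append(i)
            dists.append(d)
    return srcs, dsts, dists

# Three atoms on a line; only atoms 0 and 1 are within 1.5 of each other.
coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
srcs, dsts, dists = radius_neighbors_sketch(coords, cutoff=1.5)
```

Atom 2 ends up with no incoming edges here, which is why the graph constructor in this commit adds all nodes before adding edges.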

+def mol_to_nearest_neighbor_graph(mol,
+                                  coordinates,
+                                  neighbor_cutoff,
+                                  max_num_neighbors=None,
+                                  p_distance=2,
+                                  add_self_loop=False,
+                                  node_featurizer=None,
+                                  edge_featurizer=None,
+                                  canonical_atom_order=True,
+                                  keep_dists=False,
+                                  dist_field='dist'):
+    """Convert an RDKit molecule into a nearest neighbor graph and featurize for it.
+    Different from bigraph and complete graph, the nearest neighbor graph
+    may not be symmetric since i is the closest neighbor of j does not
+    necessarily suggest the other way.
+    Parameters
+    ----------
+    mol : rdkit.Chem.rdchem.Mol
+        RDKit molecule holder
+    coordinates : numpy.ndarray of shape (N, D)
+        The coordinates of atoms in the molecule. N for the number of atoms
+        and D for the dimensions of the coordinates.
+    neighbor_cutoff : float
+        If the distance between a pair of nodes is larger than neighbor_cutoff,
+        they will not be considered as neighboring nodes.
+    max_num_neighbors : int or None.
+        If not None, then this specifies the maximum number of neighbors
+        allowed for each atom. Default to None.
+    p_distance : int
+        We compute the distance between neighbors using Minkowski (:math:`l_p`)
+        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
+        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
+        equivalent to the standard Euclidean distance. Default to 2.
+    add_self_loop : bool
+        Whether to add self loops in DGLGraphs. Default to False.
+    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for nodes like atoms in a molecule, which can be used to update
+        ndata for a DGLGraph. Default to None.
+    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for edges like bonds in a molecule, which can be used to update
+        edata for a DGLGraph. Default to None.
+    canonical_atom_order : bool
+        Whether to use a canonical order of atoms returned by RDKit. Setting it
+        to true might change the order of atoms in the graph constructed. Default
+        to True.
+    keep_dists : bool
+        Whether to store the distance between neighboring atoms in ``edata`` of the
+        constructed DGLGraphs. Default to False.
+    dist_field : str
+        Field for storing distance between neighboring atoms in ``edata``. This comes
+        into effect only when ``keep_dists=True``. Default to ``'dist'``.
+    """
+    if canonical_atom_order:
+        new_order = rdmolfiles.CanonicalRankAtoms(mol)
+        mol = rdmolops.RenumberAtoms(mol, new_order)
+
+    srcs, dsts, dists = k_nearest_neighbors(coordinates=coordinates,
+                                            neighbor_cutoff=neighbor_cutoff,
+                                            max_num_neighbors=max_num_neighbors,
+                                            p_distance=p_distance,
+                                            self_loops=add_self_loop)
+    g = DGLGraph()
+
+    # Add nodes first since some nodes may be completely isolated
+    num_atoms = mol.GetNumAtoms()
+    g.add_nodes(num_atoms)
+
+    # Add edges
+    g.add_edges(srcs, dsts)
+
+    if node_featurizer is not None:
+        g.ndata.update(node_featurizer(mol))

-    return srcs, dsts, distances
+    if edge_featurizer is not None:
+        g.edata.update(edge_featurizer(mol))

-# Todo(Mufei): smiles_to_knn_graph, mol_to_knn_graph
+    if keep_dists:
+        assert dist_field not in g.edata, \
+            'Expect {} to be reserved for distance between neighboring atoms.'
+        g.edata[dist_field] = torch.tensor(dists).float().reshape(-1, 1)
+
+    return g
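
The assertion message in the `keep_dists` branch contains a `{}` placeholder, but no `.format(...)` call is visible in this hunk, so as shown the message would print a literal `{}`. The sketch below shows the presumably intended behavior. A plain dict stands in for `g.edata`, and `store_distances` is a hypothetical helper, not part of the commit.

```python
def store_distances(edata, dists, dist_field='dist'):
    """Store one distance per edge under dist_field, refusing to overwrite it."""
    # Fill the placeholder so the error names the conflicting field
    assert dist_field not in edata, \
        'Expect {} to be reserved for distance between neighboring atoms.'.format(dist_field)
    # One row per edge, one column, mirroring the reshape(-1, 1) in the hunk
    edata[dist_field] = [[float(d)] for d in dists]
    return edata

edata = store_distances({}, [1, 2.5])

try:
    store_distances({'dist': []}, [1.0])
    message = None
except AssertionError as err:
    message = str(err)
```

Since `assert` statements are skipped when Python runs with `-O`, a check that must always hold might be better written as an explicit `raise ValueError(...)`.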

+def smiles_to_nearest_neighbor_graph(smiles,
+                                     coordinates,
+                                     neighbor_cutoff,
+                                     max_num_neighbors=None,
+                                     p_distance=2,
+                                     add_self_loop=False,
+                                     node_featurizer=None,
+                                     edge_featurizer=None,
+                                     canonical_atom_order=True,
+                                     keep_dists=False,
+                                     dist_field='dist'):
+    """Convert a SMILES into a nearest neighbor graph and featurize for it.
+    Different from bigraph and complete graph, the nearest neighbor graph
+    may not be symmetric since i is the closest neighbor of j does not
+    necessarily suggest the other way.
+    Parameters
+    ----------
+    smiles : str
+        String of SMILES
+    coordinates : numpy.ndarray of shape (N, D)
+        The coordinates of atoms in the molecule. N for the number of atoms
+        and D for the dimensions of the coordinates.
+    neighbor_cutoff : float
+        If the distance between a pair of nodes is larger than neighbor_cutoff,
+        they will not be considered as neighboring nodes.
+    max_num_neighbors : int or None.
+        If not None, then this specifies the maximum number of neighbors
+        allowed for each atom. Default to None.
+    p_distance : int
+        We compute the distance between neighbors using Minkowski (:math:`l_p`)
+        distance. When ``p_distance = 1``, Minkowski distance is equivalent to
+        Manhattan distance. When ``p_distance = 2``, Minkowski distance is
+        equivalent to the standard Euclidean distance. Default to 2.
+    add_self_loop : bool
+        Whether to add self loops in DGLGraphs. Default to False.
+    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for nodes like atoms in a molecule, which can be used to update
+        ndata for a DGLGraph. Default to None.
+    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
+        Featurization for edges like bonds in a molecule, which can be used to update
+        edata for a DGLGraph. Default to None.
+    canonical_atom_order : bool
+        Whether to use a canonical order of atoms returned by RDKit. Setting it
+        to true might change the order of atoms in the graph constructed. Default
+        to True.
+    keep_dists : bool
+        Whether to store the distance between neighboring atoms in ``edata`` of the
+        constructed DGLGraphs. Default to False.
+    dist_field : str
+        Field for storing distance between neighboring atoms in ``edata``. This comes
+        into effect only when ``keep_dists=True``. Default to ``'dist'``.
+    """
+    mol = Chem.MolFromSmiles(smiles)
+    return mol_to_nearest_neighbor_graph(
+        mol, coordinates, neighbor_cutoff, max_num_neighbors, p_distance, add_self_loop,
+        node_featurizer, edge_featurizer, canonical_atom_order, keep_dists, dist_field)
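
Both docstrings note that a nearest neighbor graph may not be symmetric. A small worked case makes this concrete. The sketch below uses three points on a line and keeps only the single closest neighbor of each point; `nearest_neighbor_edges` is a hypothetical helper written for this illustration and is independent of the functions in the diff.

```python
def nearest_neighbor_edges(points):
    """For each point i, return one edge (j, i) from its closest other point j."""
    edges = []
    for i, pi in enumerate(points):
        others = [k for k in range(len(points)) if k != i]
        j = min(others, key=lambda k: abs(points[k] - pi))
        edges.append((j, i))  # the edge points from the neighbor to the point
    return edges

# Point 1 is the closest neighbor of point 2, but the closest neighbor
# of point 1 is point 0, so the edge 1 -> 2 has no reverse edge 2 -> 1.
edges = nearest_neighbor_edges([0.0, 1.0, 3.0])
```

With a cutoff but no cap on the number of neighbors, the relation is symmetric; the asymmetry comes from capping each atom's neighbor list.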
File renamed without changes.
2 changes: 1 addition & 1 deletion apps/life_sci/setup.py → apps/life_sci/python/setup.py
@@ -27,7 +27,7 @@
               if package.startswith('dgllife')],
     install_requires=[
         'torch>=1'
-        'scikit-learn>=0.21.2',
+        'scikit-learn>=0.22.2',
         'pandas>=0.25.1',
         'requests>=2.22.0',
         'tqdm'
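
As shown in this hunk, there is no comma after `'torch>=1'`. If the file really reads this way (the missing comma could also be an extraction artifact), Python joins the two adjacent string literals into one requirement string. The sketch below demonstrates the effect on a shortened list; the variable names are illustrative only.

```python
# Adjacent string literals are concatenated at compile time, so the first
# two lines below form a single list element.
requirements = [
    'torch>=1'
    'scikit-learn>=0.22.2',
    'pandas>=0.25.1',
]

# With the comma in place, each requirement is its own element.
fixed_requirements = [
    'torch>=1',
    'scikit-learn>=0.22.2',
    'pandas>=0.25.1',
]
```

Here `requirements` holds two entries, the first being `'torch>=1scikit-learn>=0.22.2'`, while `fixed_requirements` holds three.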