Add class-resolver package (pyg-team#4041)
* add class_resolver

* typo

* typo

* update python version
rusty1s authored Feb 9, 2022
1 parent 6002170 commit 335e6ad
Showing 8 changed files with 54 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/building-pyg-conda.yml
@@ -11,7 +11,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-18.04, macos-10.15, windows-latest]
-        python-version: [3.6, 3.7, 3.8, 3.9]
+        python-version: [3.7, 3.8, 3.9]
         torch-version: [1.9.0, 1.10.0]
         cuda-version: ['cpu', 'cu102', 'cu111', 'cu113']
         exclude:
2 changes: 1 addition & 1 deletion .github/workflows/building-rusty1s-conda.yml
@@ -11,7 +11,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-18.04, macos-10.15, windows-latest]
-        python-version: [3.6, 3.7, 3.8, 3.9]
+        python-version: [3.7, 3.8, 3.9]
         torch-version: [1.9.0, 1.10.0]
         cuda-version: ['cpu', 'cu102', 'cu111', 'cu113']
         exclude:
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
@@ -15,7 +15,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, windows-latest]
-        python-version: [3.6, 3.9]
+        python-version: [3.7, 3.9]
         torch-version: [1.9.0, 1.10.0]
         include:
           - torch-version: 1.9.0
2 changes: 1 addition & 1 deletion examples/correct_and_smooth.py
@@ -15,7 +15,7 @@

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = MLP([dataset.num_features, 200, 200, dataset.num_classes], dropout=0.5,
-            batch_norm=True, relu_first=True).to(device)
+            batch_norm=True, act_first=True).to(device)
 optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 criterion = torch.nn.CrossEntropyLoss()

3 changes: 2 additions & 1 deletion setup.py
@@ -14,6 +14,7 @@
     'pyparsing',
     'hydra-core',
     'scikit-learn',
+    'class-resolver>=0.3.2',
     'googledrivedownloader',
 ]

@@ -54,7 +55,7 @@
         'graph-neural-networks',
         'graph-convolutional-networks',
     ],
-    python_requires='>=3.6',
+    python_requires='>=3.7',
     install_requires=install_requires,
     extras_require={
         'full': full_install_requires,
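For context (not part of this diff): the new class-resolver dependency ships ready-made resolvers for common PyTorch components. The activation_resolver imported in torch_geometric/nn/models/mlp.py below turns a string such as 'relu', plus an optional dict of keyword arguments, into an instantiated activation module. A minimal sketch, assuming class-resolver >= 0.3.2 as pinned above:

from class_resolver.contrib.torch import activation_resolver

# Resolve activation modules by normalized name; the optional dict of
# keyword arguments is forwarded to the module's constructor.
act = activation_resolver.make('relu')                                   # -> torch.nn.ReLU()
leaky = activation_resolver.make('leaky_relu', {'negative_slope': 0.2})  # -> torch.nn.LeakyReLU(0.2)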
8 changes: 4 additions & 4 deletions test/nn/models/test_mlp.py
@@ -6,13 +6,13 @@
 from torch_geometric.nn import MLP


-@pytest.mark.parametrize('batch_norm,relu_first',
+@pytest.mark.parametrize('batch_norm,act_first',
                          product([False, True], [False, True]))
-def test_mlp(batch_norm, relu_first):
+def test_mlp(batch_norm, act_first):
     x = torch.randn(4, 16)

     torch.manual_seed(12345)
-    mlp = MLP([16, 32, 32, 64], batch_norm=batch_norm, relu_first=relu_first)
+    mlp = MLP([16, 32, 32, 64], batch_norm=batch_norm, act_first=act_first)
     assert str(mlp) == 'MLP(16, 32, 32, 64)'
     out = mlp(x)
     assert out.size() == (4, 64)
@@ -22,5 +22,5 @@ def test_mlp(batch_norm, relu_first):

     torch.manual_seed(12345)
     mlp = MLP(16, hidden_channels=32, out_channels=64, num_layers=3,
-              batch_norm=batch_norm, relu_first=relu_first)
+              batch_norm=batch_norm, act_first=act_first)
     assert torch.allclose(mlp(x), out)
6 changes: 3 additions & 3 deletions torch_geometric/nn/models/linkx.py
@@ -97,16 +97,16 @@ def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
         if self.num_edge_layers > 1:
             self.edge_norm = BatchNorm1d(hidden_channels)
             channels = [hidden_channels] * num_edge_layers
-            self.edge_mlp = MLP(channels, dropout=0., relu_first=True)
+            self.edge_mlp = MLP(channels, dropout=0., act_first=True)

         channels = [in_channels] + [hidden_channels] * num_node_layers
-        self.node_mlp = MLP(channels, dropout=0., relu_first=True)
+        self.node_mlp = MLP(channels, dropout=0., act_first=True)

         self.cat_lin1 = torch.nn.Linear(hidden_channels, hidden_channels)
         self.cat_lin2 = torch.nn.Linear(hidden_channels, hidden_channels)

         channels = [hidden_channels] * num_layers + [out_channels]
-        self.final_mlp = MLP(channels, dropout=dropout, relu_first=True)
+        self.final_mlp = MLP(channels, dropout=dropout, act_first=True)

         self.reset_parameters()
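The changes to examples/correct_and_smooth.py, test/nn/models/test_mlp.py, and torch_geometric/nn/models/linkx.py above only rename the relu_first argument to act_first. As an illustrative sketch (not part of the commit), both spellings still construct the same module, since MLP.__init__ maps the deprecated argument onto the new one (see the mlp.py diff below):

from torch_geometric.nn import MLP

# Deprecated spelling, still accepted for backward compatibility:
old_style = MLP([16, 32, 32], dropout=0., relu_first=True)
# Preferred spelling after this change:
new_style = MLP([16, 32, 32], dropout=0., act_first=True)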
55 changes: 41 additions & 14 deletions torch_geometric/nn/models/mlp.py
@@ -1,7 +1,8 @@
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.nn.functional as F
+from class_resolver.contrib.torch import activation_resolver
 from torch import Tensor
 from torch.nn import BatchNorm1d, Identity

@@ -32,8 +33,8 @@ class MLP(torch.nn.Module):
     Args:
         channel_list (List[int] or int, optional): List of input, intermediate
-            and output channels. :obj:`len(channel_list) - 1` denotes the
-            number of layers of the MLP (default: :obj:`None`)
+            and output channels such that :obj:`len(channel_list) - 1` denotes
+            the number of layers of the MLP (default: :obj:`None`)
         in_channels (int, optional): Size of each input sample.
             Will override :attr:`channel_list`. (default: :obj:`None`)
         hidden_channels (int, optional): Size of each hidden sample.
@@ -44,10 +45,22 @@ class MLP(torch.nn.Module):
             Will override :attr:`channel_list`. (default: :obj:`None`)
         dropout (float, optional): Dropout probability of each hidden
             embedding. (default: :obj:`0.`)
+        act (str or Callable, optional): The non-linear activation function to
+            use. (default: :obj:`"relu"`)
         batch_norm (bool, optional): If set to :obj:`False`, will not make use
             of batch normalization. (default: :obj:`True`)
-        relu_first (bool, optional): If set to :obj:`True`, ReLU activation is
-            applied before batch normalization. (default: :obj:`False`)
+        act_first (bool, optional): If set to :obj:`True`, activation is
+            applied before normalization. (default: :obj:`False`)
+        act_kwargs (Dict[str, Any], optional): Arguments passed to the
+            respective activation function defined by :obj:`act`.
+            (default: :obj:`None`)
+        batch_norm_kwargs (Dict[str, Any], optional): Arguments passed to
+            :class:`torch.nn.BatchNorm1d` in case :obj:`batch_norm == True`.
+            (default: :obj:`None`)
+        bias (bool, optional): If set to :obj:`False`, the module will not
+            learn additive biases. (default: :obj:`True`)
+        relu_first (bool, optional): Deprecated in favor of :obj:`act_first`.
+            (default: :obj:`False`)
     """
     def __init__(
         self,
@@ -58,10 +71,17 @@ def __init__(
         out_channels: Optional[int] = None,
         num_layers: Optional[int] = None,
         dropout: float = 0.,
+        act: str = "relu",
         batch_norm: bool = True,
+        act_first: bool = False,
+        act_kwargs: Optional[Dict[str, Any]] = None,
+        batch_norm_kwargs: Optional[Dict[str, Any]] = None,
+        bias: bool = True,
         relu_first: bool = False,
     ):
         super().__init__()
+        act_first = act_first or relu_first  # Backward compatibility.
+        batch_norm_kwargs = batch_norm_kwargs or {}

         if isinstance(channel_list, int):
             in_channels = channel_list
@@ -74,16 +94,23 @@ def __init__(
         assert isinstance(channel_list, (tuple, list))
         assert len(channel_list) >= 2
         self.channel_list = channel_list

         self.dropout = dropout
-        self.relu_first = relu_first
+        self.act = activation_resolver.make(act, act_kwargs)
+        self.act_first = act_first

         self.lins = torch.nn.ModuleList()
-        for dims in zip(channel_list[:-1], channel_list[1:]):
-            self.lins.append(Linear(*dims))
+        pairwise = zip(channel_list[:-1], channel_list[1:])
+        for in_channels, out_channels in pairwise:
+            self.lins.append(Linear(in_channels, out_channels, bias=bias))

         self.norms = torch.nn.ModuleList()
-        for dim in zip(channel_list[1:-1]):
-            self.norms.append(BatchNorm1d(dim) if batch_norm else Identity())
+        for hidden_channels in channel_list[1:-1]:
+            if batch_norm:
+                norm = BatchNorm1d(hidden_channels, **batch_norm_kwargs)
+            else:
+                norm = Identity()
+            self.norms.append(norm)

         self.reset_parameters()

@@ -113,11 +140,11 @@ def forward(self, x: Tensor) -> Tensor:
         """"""
         x = self.lins[0](x)
         for lin, norm in zip(self.lins[1:], self.norms):
-            if self.relu_first:
-                x = x.relu_()
+            if self.act_first:
+                x = self.act(x)
             x = norm(x)
-            if not self.relu_first:
-                x = x.relu_()
+            if not self.act_first:
+                x = self.act(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = lin.forward(x)
         return x
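Putting the extended interface together, a hedged usage sketch (argument names taken from the diff above; the concrete values are illustrative only): act is resolved by name through activation_resolver, act_kwargs is forwarded to the resolved activation, and batch_norm_kwargs to torch.nn.BatchNorm1d.

import torch
from torch_geometric.nn import MLP

mlp = MLP(
    [16, 32, 32, 64],
    dropout=0.5,
    act='leaky_relu',                      # resolved via activation_resolver
    act_kwargs={'negative_slope': 0.2},    # forwarded to LeakyReLU
    act_first=True,                        # apply the activation before normalization
    batch_norm_kwargs={'momentum': 0.05},  # forwarded to BatchNorm1d
    bias=False,                            # no additive biases in the linear layers
)
out = mlp(torch.randn(4, 16))              # -> tensor of shape [4, 64]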
