forked from yongqyu/MolGAN-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
90 lines (70 loc) · 3.37 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from layers import GraphConvolution, GraphAggregation
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + instance-norm stages added back onto the input.

    Args:
        dim_in: number of channels of the incoming feature map.
        dim_out: number of channels produced by both convolutions. The
            residual sum in ``forward`` requires the branch output to be
            shape-compatible with the input, so in practice
            ``dim_out == dim_in``.
    """

    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=False)
        norm_opts = dict(affine=True, track_running_stats=True)
        # Order matters: the Sequential indices (main.0 ... main.4) are the
        # state_dict keys used by saved checkpoints.
        stages = [
            nn.Conv2d(dim_in, dim_out, **conv_opts),
            nn.InstanceNorm2d(dim_out, **norm_opts),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, **conv_opts),
            nn.InstanceNorm2d(dim_out, **norm_opts),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        shortcut = x
        return shortcut + self.main(x)
class Generator(nn.Module):
    """Generator network: latent vector -> (edge logits, node logits).

    The input passes through a stack of ``Linear -> Tanh -> Dropout`` layers
    whose widths are given by ``conv_dims``. Two linear heads then produce
    the edge and node logits.

    Args:
        conv_dims: widths of the hidden dense layers (list or tuple). If it
            is empty, the heads read the latent vector directly.
        z_dim: size of the input latent vector (last dim of ``x``).
        vertexes: number of vertices per generated graph.
        edges: number of edge channels per vertex pair.
        nodes: number of node channels per vertex.
        dropout: dropout probability used after every hidden layer and on
            both output heads.
    """

    def __init__(self, conv_dims, z_dim, vertexes, edges, nodes, dropout):
        super(Generator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes

        # `[z_dim] + conv_dims[:-1]` needs a list; accept tuples too.
        conv_dims = list(conv_dims)
        layers = []
        for c0, c1 in zip([z_dim] + conv_dims[:-1], conv_dims):
            layers.append(nn.Linear(c0, c1))
            layers.append(nn.Tanh())
            # Must NOT be in-place: Tanh saves its output for backward, so
            # overwriting that output in place makes backward() raise
            # whenever dropout > 0 in training mode.
            layers.append(nn.Dropout(p=dropout))
        self.layers = nn.Sequential(*layers)

        # An empty `conv_dims` means the (empty, identity) stack passes z through.
        hidden_dim = conv_dims[-1] if conv_dims else z_dim
        self.edges_layer = nn.Linear(hidden_dim, edges * vertexes * vertexes)
        self.nodes_layer = nn.Linear(hidden_dim, vertexes * nodes)
        self.dropout = nn.Dropout(p=dropout)

    @property
    def dropoout(self):
        """Backward-compatible alias for the old, misspelled attribute name."""
        return self.dropout

    def forward(self, x):
        """Map latent vectors to edge and node logits.

        Args:
            x: latent tensor whose last dim is ``z_dim``.

        Returns:
            ``(edges_logits, nodes_logits)`` with shapes
            ``(batch, vertexes, vertexes, edges)`` and
            ``(batch, vertexes, nodes)``.
        """
        output = self.layers(x)

        edges_logits = self.edges_layer(output).view(
            -1, self.edges, self.vertexes, self.vertexes)
        # Average with the transpose over the two vertex axes so that
        # logits[..., i, j] == logits[..., j, i].
        edges_logits = (edges_logits + edges_logits.permute(0, 1, 3, 2)) / 2
        # Move the edge channel last.
        # NOTE(review): dropout runs after symmetrisation and masks entries
        # independently, so in training mode the returned logits are generally
        # no longer symmetric.
        edges_logits = self.dropout(edges_logits.permute(0, 2, 3, 1))

        nodes_logits = self.nodes_layer(output)
        nodes_logits = self.dropout(nodes_logits.view(-1, self.vertexes, self.nodes))
        return edges_logits, nodes_logits
class Discriminator(nn.Module):
    """Graph discriminator: graph convolution -> aggregation -> dense layers -> score.

    Each graph in the batch gets one value from a single-unit linear head.

    Args:
        conv_dim: triple ``(graph_conv_dim, aux_dim, linear_dim)``.
            ``graph_conv_dim`` is the sequence of graph-convolution widths.
            ``aux_dim`` is the output width of the aggregation layer.
            ``linear_dim`` is the (possibly empty) sequence of dense-layer
            widths, given as a list or tuple.
        m_dim: first argument to ``GraphConvolution``; presumably the
            per-node input feature width (confirm against ``layers.py``).
        b_dim: number of edge channels passed to the graph layers; presumably
            the adjacency channel count minus the dropped channel 0 (see
            ``forward``) — confirm.
        dropout: dropout probability used by every sub-layer.
    """

    def __init__(self, conv_dim, m_dim, b_dim, dropout):
        super(Discriminator, self).__init__()

        graph_conv_dim, aux_dim, linear_dim = conv_dim
        # `[aux_dim] + linear_dim[:-1]` needs a list; accept tuples too.
        linear_dim = list(linear_dim)

        # Graph part.
        self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout)
        self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, b_dim, dropout)

        # Dense stack: one Linear -> Dropout pair per entry of linear_dim.
        # NOTE(review): there is no non-linearity between these Linear layers,
        # so in eval mode the stack collapses to one affine map. The reference
        # MolGAN implementation presumably uses tanh here (TODO confirm).
        # Adding it would shift the Sequential indices (checkpoint keys) and
        # change pretrained behaviour, so it is left as-is.
        layers = []
        for c0, c1 in zip([aux_dim] + linear_dim[:-1], linear_dim):
            layers.append(nn.Linear(c0, c1))
            layers.append(nn.Dropout(dropout))
        self.linear_layer = nn.Sequential(*layers)

        # With an empty linear_dim the dense stack is the identity, so the
        # head reads the aggregation output (width aux_dim) directly.
        head_in = linear_dim[-1] if linear_dim else aux_dim
        self.output_layer = nn.Linear(head_in, 1)

    def forward(self, adj, hidden, node, activatation=None):
        """Score a batch of graphs.

        Args:
            adj: 4-D adjacency tensor. Channel 0 of the last axis is dropped
                and the remaining channels are moved to dim 1 before the
                graph convolution. Presumably (batch, vertexes, vertexes,
                edge_channels); confirm against callers.
            hidden: optional extra node features, concatenated with ``node``
                on the last axis; ``None`` to use ``node`` alone.
            node: node feature tensor (features on the last axis).
            activatation: optional callable applied to the raw score. The
                misspelled name is kept for keyword-argument compatibility.

        Returns:
            ``(output, h)``: the (optionally activated) score from the
            single-unit head, and the dense-stack features it was computed from.
        """
        adj = adj[:, :, :, 1:].permute(0, 3, 1, 2)

        if hidden is not None:
            annotations = torch.cat((hidden, node), -1)
        else:
            annotations = node
        h = self.gcn_layer(annotations, adj)

        # Re-attach the raw inputs to the convolved features before aggregation.
        parts = (h, node) if hidden is None else (h, hidden, node)
        annotations = torch.cat(parts, -1)
        h = self.agg_layer(annotations, torch.tanh)
        h = self.linear_layer(h)

        # TODO: batch discriminator is not implemented yet.
        output = self.output_layer(h)
        if activatation is not None:
            output = activatation(output)
        return output, h