diff --git a/gcn.py b/gcn.py index 6e04b69..542e54d 100644 --- a/gcn.py +++ b/gcn.py @@ -143,11 +143,13 @@ def __init__(self, in_channels, out_channels, gcn_type, gcn_partition=None): else: if gcn_type == 'gat': self.adj_available = False - if gcn_type in ['normal', 'cheb', 'graph']: + if gcn_type in ['normal', 'cheb']: self.batch_training = True self.kwargs['node_dim'] = 1 if gcn_type == 'cheb': self.kwargs['K'] = 3 + if gcn_type == 'sage': + self.kwargs['concat'] = True GCNCell = {'normal':PyG.GCNConv, 'cheb':PyG.ChebConv,