Skip to content

Commit

Permalink
remove old medge implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
StefReck committed Apr 20, 2021
1 parent ccb4255 commit c0012a5
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 78 deletions.
1 change: 0 additions & 1 deletion examples/model_files/graph_disjoint_edgeconv.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# ParticleNet, using my custom DisjointEdgeConv block
# in the config, its mandatory to use the fixed_batchsize option
# use a sample_modifier to produce the 3 given inputs (e.g. GraphEdgeConv)

[model]
Expand Down
25 changes: 0 additions & 25 deletions examples/model_files/graph_medgeconv.toml

This file was deleted.

38 changes: 0 additions & 38 deletions orcanet/builder_util/layer_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,44 +224,6 @@ def __call__(self, inputs):
return x


@register
class MEdgeConvBlock:
""" EdgeConv as defined in ParticleNet, see github.com/StefReck/MEdgeConv """
def __init__(self, units,
next_neighbors=16,
shortcut=True,
batchnorm_for_nodes=False,
pooling=False,
kernel_initializer="glorot_uniform",
activation="relu"):
self.units = units
self.next_neighbors = next_neighbors
self.batchnorm_for_nodes = batchnorm_for_nodes
self.shortcut = shortcut
self.pooling = pooling
self.kernel_initializer = kernel_initializer
self.activation = activation

def __call__(self, x):
nodes, is_valid, coordinates = x

if self.batchnorm_for_nodes:
nodes = layers.BatchNormalization()(nodes)

nodes = medgeconv.EdgeConv(
units=self.units,
next_neighbors=self.next_neighbors,
kernel_initializer=self.kernel_initializer,
activation=self.activation,
shortcut=self.shortcut,
)((nodes, is_valid, coordinates))

if self.pooling:
return medgeconv.GlobalAvgValidPooling()((nodes, is_valid))
else:
return nodes, is_valid, nodes


@register
class ResnetBlock:
"""
Expand Down
14 changes: 0 additions & 14 deletions orcanet/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,6 @@ def test_lstm(self):
self.assertSequenceEqual(model.output_shape, (None, 3))
self.assertEqual(model.count_params(), 11321)

def test_medgeconv(self):
toml_file = "graph_medgeconv.toml"

model_file = os.path.join(self.example_dir, toml_file)
mb = orcanet.model_builder.ModelBuilder(model_file)
orga = self.get_orga(dims="graph")
orga.cfg.batchsize = 64
orga.cfg.fixed_batchsize = True
model = mb.build(orga)

self.assertSequenceEqual(model.output_shape, (64, 3))
self.assertEqual(model.count_params(), 304223)
self.assertEqual(len(model.layers), 87)

def test_disjoint_edgeconv(self):
toml_file = "graph_disjoint_edgeconv.toml"

Expand Down

0 comments on commit c0012a5

Please sign in to comment.