-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsortingnetwork.py
122 lines (103 loc) · 4.33 KB
/
sortingnetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# This file is part of DEAP.
#
# DEAP is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version.
#
# DEAP is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with DEAP. If not, see <http://www.gnu.org/licenses/>.
from itertools import product
class SortingNetwork(list):
"""Sorting network class.
From Wikipedia : A sorting network is an abstract mathematical model
of a network of wires and comparator modules that is used to sort a
sequence of numbers. Each comparator connects two wires and sort the
values by outputting the smaller value to one wire, and a larger
value to the other.
"""
def __init__(self, dimension, connectors = []):
self.dimension = dimension
for wire1, wire2 in connectors:
self.addConnector(wire1, wire2)
def addConnector(self, wire1, wire2):
"""Add a connector between wire1 and wire2 in the network."""
if wire1 == wire2:
return
if wire1 > wire2:
wire1, wire2 = wire2, wire1
index = 0
for level in reversed(self):
if self.checkConflict(level, wire1, wire2):
break
index -= 1
if index == 0:
self.append([(wire1, wire2)])
else:
self[index].append((wire1, wire2))
def checkConflict(self, level, wire1, wire2):
"""Check if a connection between `wire1` and `wire2` can be
added on this `level`."""
for wires in level:
if wires[1] >= wire1 and wires[0] <= wire2:
return True
def sort(self, values):
"""Sort the values in-place based on the connectors in the network."""
for level in self:
for wire1, wire2 in level:
if values[wire1] > values[wire2]:
values[wire1], values[wire2] = values[wire2], values[wire1]
def assess(self, cases=None):
"""Try to sort the **cases** using the network, return the number of
misses. If **cases** is None, test all possible cases according to
the network dimensionality.
"""
if cases is None:
cases = product((0, 1), repeat=self.dimension)
misses = 0
ordered = [[0]*(self.dimension-i) + [1]*i for i in range(self.dimension+1)]
for sequence in cases:
sequence = list(sequence)
self.sort(sequence)
misses += (sequence != ordered[sum(sequence)])
return misses
def draw(self):
"""Return an ASCII representation of the network."""
str_wires = [["-"]*7 * self.depth]
str_wires[0][0] = "0"
str_wires[0][1] = " o"
str_spaces = []
for i in range(1, self.dimension):
str_wires.append(["-"]*7 * self.depth)
str_spaces.append([" "]*7 * self.depth)
str_wires[i][0] = str(i)
str_wires[i][1] = " o"
for index, level in enumerate(self):
for wire1, wire2 in level:
str_wires[wire1][(index+1)*6] = "x"
str_wires[wire2][(index+1)*6] = "x"
for i in range(wire1, wire2):
str_spaces[i][(index+1)*6+1] = "|"
for i in range(wire1+1, wire2):
str_wires[i][(index+1)*6] = "|"
network_draw = "".join(str_wires[0])
for line, space in zip(str_wires[1:], str_spaces):
network_draw += "\n"
network_draw += "".join(space)
network_draw += "\n"
network_draw += "".join(line)
return network_draw
@property
def depth(self):
"""Return the number of parallel steps that it takes to sort any input.
"""
return len(self)
@property
def length(self):
"""Return the number of comparison-swap used."""
return sum(len(level) for level in self)