Skip to content

Commit

Permalink
WMCM slides
Browse files Browse the repository at this point in the history
  • Loading branch information
William Fiset authored and William Fiset committed Jun 5, 2021
1 parent 5501f3a commit 5503735
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 67 deletions.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ public class WeightedMaximumCardinalityMatchingRecursive implements MwpmInterfac
private double minWeightCost;
private int[] matching;

// The cost matrix should be a symmetric (i.e cost[i][j] = cost[j][i])
public WeightedMaximumCardinalityMatchingRecursive(double[][] cost) {
// The cost matrix should be a symmetric (i.e cost[i][j] = cost[j][i]) and have a cost of `null`
// between nodes i and j if no edge exists between those two nodes.
public WeightedMaximumCardinalityMatchingRecursive(Double[][] cost) {
if (cost == null) throw new IllegalArgumentException("Input cannot be null");
n = cost.length;
if (n <= 1) throw new IllegalArgumentException("Invalid matrix size: " + n);
Expand All @@ -46,16 +47,17 @@ public WeightedMaximumCardinalityMatchingRecursive(double[][] cost) {
// Sets the cost matrix. If the number of nodes in the graph is odd, add an artificial
// node that connects to every other node with a cost of infinity. This will make it easy
// to find a perfect matching and remove in the artificial node in the end.
private void setCostMatrix(double[][] inputMatrix) {
double[][] newCostMatrix = inputMatrix;
private void setCostMatrix(Double[][] inputMatrix) {
double[][] newCostMatrix = null;
if (n % 2 != 0) {
isOdd = true;
newCostMatrix = new double[n + 1][n + 1];
double maxValue = Double.MIN_VALUE;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
newCostMatrix[i][j] = inputMatrix[i][j];
maxValue = Math.max(maxValue, inputMatrix[i][j]);
double edgeCost = inputMatrix[i][j] == null ? INF : inputMatrix[i][j];
newCostMatrix[i][j] = edgeCost;
maxValue = Math.max(maxValue, edgeCost);
}
}
if (maxValue > INF) {
Expand All @@ -68,6 +70,14 @@ private void setCostMatrix(double[][] inputMatrix) {
newCostMatrix[n][n] = 0;
artificialNodeId = n;
n++;
} else {
newCostMatrix = new double[n][n];
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
double edgeCost = inputMatrix[i][j] == null ? INF : inputMatrix[i][j];
newCostMatrix[i][j] = edgeCost;
}
}
}
this.cost = newCostMatrix;
}
Expand Down Expand Up @@ -105,9 +115,12 @@ private void solve() {
int[] history = new int[1 << n];
minWeightCost = f(END_STATE, dp, history);
// Remove the cost of the artificial node
if (isOdd) {
minWeightCost -= INF;
}
// if (isOdd) {
// minWeightCost -= INF;
// }
// TODO(william): This isn't very elegant, or error proof...
// Remove cost of edges that were matched, but which don't actually exist.
while (minWeightCost > INF) minWeightCost -= INF;
reconstructMatching(history);
solved = true;
}
Expand Down Expand Up @@ -153,6 +166,8 @@ private void reconstructMatching(int[] history) {
int[] map = new int[n];
int[] leftNodes = new int[n / 2];

int matchingSize = 0;

// Reconstruct the matching of pairs of nodes working backwards through computed states.
for (int i = 0, state = END_STATE; state != 0; state = history[state]) {
// Isolate the pair used by xoring the state with the state used to generate it.
Expand All @@ -163,13 +178,18 @@ private void reconstructMatching(int[] history) {

leftNodes[i++] = leftNode;
map[leftNode] = rightNode;

if (cost[leftNode][rightNode] != INF) matchingSize++;
}

matchingSize = matchingSize * 2;

// Sort the left nodes in ascending order.
java.util.Arrays.sort(leftNodes);

int m = isOdd ? n - 2 : n;
matching = new int[m];
// int m = isOdd ? n - 2 : n;
// matching = new int[m];
matching = new int[matchingSize];

for (int i = 0, j = 0; i < n / 2; i++) {
int leftNode = leftNodes[i];
Expand All @@ -178,9 +198,12 @@ private void reconstructMatching(int[] history) {
if (isOdd && (leftNode == artificialNodeId || rightNode == artificialNodeId)) {
continue;
}
matching[2 * j] = leftNode;
matching[2 * j + 1] = rightNode;
j++;
// Only match edges which actually exist
if (cost[leftNode][rightNode] != INF) {
matching[2 * j] = leftNode;
matching[2 * j + 1] = rightNode;
j++;
}
}
}

Expand All @@ -203,13 +226,13 @@ public static void main(String[] args) {

private static void test() {
// mwpm is expected to be between nodes: 0 & 5, 1 & 2, 3 & 4
double[][] costMatrix = {
{0, 9, 9, 9, 9, 1},
{9, 0, 1, 9, 9, 9},
{9, 1, 0, 9, 9, 9},
{9, 9, 9, 0, 1, 9},
{9, 9, 9, 1, 0, 9},
{1, 9, 9, 9, 9, 0},
Double[][] costMatrix = {
{0.0, 9.0, 9.0, 9.0, 9.0, 1.0},
{9.0, 0.0, 1.0, 9.0, 9.0, 9.0},
{9.0, 1.0, 0.0, 9.0, 9.0, 9.0},
{9.0, 9.0, 9.0, 0.0, 1.0, 9.0},
{9.0, 9.0, 9.0, 1.0, 0.0, 9.0},
{1.0, 9.0, 9.0, 9.0, 9.0, 0.0},
};

WeightedMaximumCardinalityMatchingRecursive mwpm =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
public class WeightedMaximumCardinalityMatchingTest {

static final int LOOPS = 50;
static final int INF = 1000000;

static class BruteForceMwpm {
private int n;
private double[][] matrix;
private Double[][] matrix;
private double minWeightMatching = Double.POSITIVE_INFINITY;

public BruteForceMwpm(double[][] matrix) {
public BruteForceMwpm(Double[][] matrix) {
this.matrix = matrix;
this.n = matrix.length;
}
Expand Down Expand Up @@ -75,33 +74,32 @@ private static void swap(int[] sequence, int i, int j) {
}
}

private static MwpmInterface[] getImplementations(double[][] costMatrix) {
return new MwpmInterface[] {
new WeightedMaximumCardinalityMatchingRecursive(costMatrix),
new WeightedMaximumCardinalityMatchingIterative(costMatrix)
private static MwpmInterface[] getImplementations(Double[][] costMatrix) {
return new MwpmInterface[] {new WeightedMaximumCardinalityMatchingRecursive(costMatrix)
// new WeightedMaximumCardinalityMatchingIterative(costMatrix)
};
}

private static double[][] createEmptyMatrix(int n) {
double[][] costMatrix = new double[n][n];
private static Double[][] createEmptyMatrix(int n) {
Double[][] costMatrix = new Double[n][n];
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if (i == j) continue;
costMatrix[i][j] = INF;
costMatrix[i][j] = null;
}
}
return costMatrix;
}

private static void addUndirectedWeightedEdge(double[][] g, int from, int to, double weight) {
private static void addUndirectedWeightedEdge(Double[][] g, int from, int to, double weight) {
g[from][to] = weight;
g[to][from] = weight;
}

@Test
public void testSmallGraph_oddSize() {
int n = 5;
double[][] g = createEmptyMatrix(n);
Double[][] g = createEmptyMatrix(n);
// 0, 2; 3, 4
addUndirectedWeightedEdge(g, 0, 1, 8);
addUndirectedWeightedEdge(g, 0, 2, 1);
Expand All @@ -122,9 +120,9 @@ public void testSmallGraph_oddSize() {
@Test
public void testSmallestMatrix1() {
// nodes 0 & 1 make the mwpm
double[][] costMatrix = {
{0, 1},
{1, 0},
Double[][] costMatrix = {
{0.0, 1.0},
{1.0, 0.0},
};
MwpmInterface[] impls = getImplementations(costMatrix);
for (MwpmInterface mwpm : impls) {
Expand All @@ -140,11 +138,11 @@ public void testSmallestMatrix1() {
@Test
public void testSmallMatrix1() {
// nodes 0 & 2 and 1 & 3 make the mwpm
double[][] costMatrix = {
{0, 2, 1, 2},
{2, 0, 2, 1},
{1, 2, 0, 2},
{2, 1, 2, 0},
Double[][] costMatrix = {
{0.0, 2.0, 1.0, 2.0},
{2.0, 0.0, 2.0, 1.0},
{1.0, 2.0, 0.0, 2.0},
{2.0, 1.0, 2.0, 0.0},
};

MwpmInterface[] impls = getImplementations(costMatrix);
Expand All @@ -161,11 +159,11 @@ public void testSmallMatrix1() {
@Test
public void testSmallMatrix2() {
// nodes 0 & 1 and 2 & 3 make the mwpm
double[][] costMatrix = {
{0, 1, 2, 2},
{1, 0, 2, 2},
{2, 2, 0, 1},
{2, 2, 1, 0},
Double[][] costMatrix = {
{0.0, 1.0, 2.0, 2.0},
{1.0, 0.0, 2.0, 2.0},
{2.0, 2.0, 0.0, 1.0},
{2.0, 2.0, 1.0, 0.0},
};

MwpmInterface[] impls = getImplementations(costMatrix);
Expand All @@ -182,13 +180,13 @@ public void testSmallMatrix2() {
@Test
public void testMediumMatrix1() {
// mwpm between 0 & 5, 1 & 2, 3 & 4
double[][] costMatrix = {
{0, 9, 9, 9, 9, 1},
{9, 0, 1, 9, 9, 9},
{9, 1, 0, 9, 9, 9},
{9, 9, 9, 0, 1, 9},
{9, 9, 9, 1, 0, 9},
{1, 9, 9, 9, 9, 0},
Double[][] costMatrix = {
{0.0, 9.0, 9.0, 9.0, 9.0, 1.0},
{9.0, 0.0, 1.0, 9.0, 9.0, 9.0},
{9.0, 1.0, 0.0, 9.0, 9.0, 9.0},
{9.0, 9.0, 9.0, 0.0, 1.0, 9.0},
{9.0, 9.0, 9.0, 1.0, 0.0, 9.0},
{1.0, 9.0, 9.0, 9.0, 9.0, 0.0},
};

MwpmInterface[] impls = getImplementations(costMatrix);
Expand All @@ -205,13 +203,13 @@ public void testMediumMatrix1() {
@Test
public void testMediumMatrix2() {
// mwpm between 0 & 1, 2 & 4, 3 & 5
double[][] costMatrix = {
{0, 1, 9, 9, 9, 9},
{1, 0, 9, 9, 9, 9},
{9, 9, 0, 9, 1, 9},
{9, 9, 9, 0, 9, 1},
{9, 9, 1, 9, 0, 9},
{9, 9, 9, 1, 9, 0},
Double[][] costMatrix = {
{0.0, 1.0, 9.0, 9.0, 9.0, 9.0},
{1.0, 0.0, 9.0, 9.0, 9.0, 9.0},
{9.0, 9.0, 0.0, 9.0, 1.0, 9.0},
{9.0, 9.0, 9.0, 0.0, 9.0, 1.0},
{9.0, 9.0, 1.0, 9.0, 0.0, 9.0},
{9.0, 9.0, 9.0, 1.0, 9.0, 0.0},
};

MwpmInterface[] impls = getImplementations(costMatrix);
Expand All @@ -228,11 +226,11 @@ public void testMediumMatrix2() {
@Test
public void testMediumGraph_evenSize_fromSlides() {
int n = 6;
double[][] g = createEmptyMatrix(n);
Double[][] g = createEmptyMatrix(n);

addUndirectedWeightedEdge(g, 0, 1, 7);
addUndirectedWeightedEdge(g, 0, 2, 6);
addUndirectedWeightedEdge(g, 0, 4, 11);
addUndirectedWeightedEdge(g, 0, 4, -1);
addUndirectedWeightedEdge(g, 1, 3, 1);
addUndirectedWeightedEdge(g, 1, 4, 3);
addUndirectedWeightedEdge(g, 1, 5, 5);
Expand All @@ -250,11 +248,33 @@ public void testMediumGraph_evenSize_fromSlides() {
assertThat(matching).isEqualTo(expectedMatching);
}

@Test
public void testMediumGraph_evenSize_nonPerfectMatchingFromSlides() {
int n = 6;
Double[][] g = createEmptyMatrix(n);

addUndirectedWeightedEdge(g, 0, 1, 6);
addUndirectedWeightedEdge(g, 1, 2, 7);
addUndirectedWeightedEdge(g, 1, 5, 8);
addUndirectedWeightedEdge(g, 1, 4, 9);
addUndirectedWeightedEdge(g, 1, 3, 10);
addUndirectedWeightedEdge(g, 3, 4, 11);

MwpmInterface mwpm = new WeightedMaximumCardinalityMatchingRecursive(g);
double cost = mwpm.getMinWeightCost();
assertThat(cost).isEqualTo(17);

int[] matching = mwpm.getMatching();

int[] expectedMatching = {0, 1, 3, 4};
assertThat(matching).isEqualTo(expectedMatching);
}

@Test
public void testMatchingOutputsUniqueNodes() {
for (int loop = 0; loop < LOOPS; loop++) {
int n = Math.max(1, (int) (Math.random() * 11)) * 2; // n is either 2,4,6,8,10,12,14,16,18,20
double[][] costMatrix = new double[n][n];
Double[][] costMatrix = new Double[n][n];
randomFillSymmetricMatrix(costMatrix, 100);

MwpmInterface[] impls = getImplementations(costMatrix);
Expand All @@ -274,7 +294,7 @@ public void testMatchingOutputsUniqueNodes() {
public void testMatchingAndCostAreConsistent() {
for (int loop = 0; loop < LOOPS; loop++) {
int n = Math.max(1, (int) (Math.random() * 11)) * 2; // n is either 2,4,6,8,10,12,14,16,18,20
double[][] costMatrix = new double[n][n];
Double[][] costMatrix = new Double[n][n];
randomFillSymmetricMatrix(costMatrix, 100);

MwpmInterface[] impls = getImplementations(costMatrix);
Expand All @@ -295,7 +315,7 @@ public void testMatchingAndCostAreConsistent() {
public void testAgainstBruteForce_largeValues() {
for (int loop = 0; loop < LOOPS; loop++) {
int n = Math.max(1, (int) (Math.random() * 6)) * 2; // n is either 2,4,6,8, or 10
double[][] costMatrix = new double[n][n];
Double[][] costMatrix = new Double[n][n];
randomFillSymmetricMatrix(costMatrix, /*maxValue=*/ 10000);

MwpmInterface[] impls = getImplementations(costMatrix);
Expand All @@ -313,7 +333,7 @@ public void testAgainstBruteForce_largeValues() {
public void testAgainstBruteForce_smallValues() {
for (int loop = 0; loop < LOOPS; loop++) {
int n = Math.max(1, (int) (Math.random() * 6)) * 2; // n is either 2,4,6,8, or 10
double[][] costMatrix = new double[n][n];
Double[][] costMatrix = new Double[n][n];
randomFillSymmetricMatrix(costMatrix, /*maxValue=*/ 3);

MwpmInterface[] impls = getImplementations(costMatrix);
Expand All @@ -328,7 +348,7 @@ public void testAgainstBruteForce_smallValues() {
}
}

public void randomFillSymmetricMatrix(double[][] dist, int maxValue) {
public void randomFillSymmetricMatrix(Double[][] dist, int maxValue) {
for (int i = 0; i < dist.length; i++) {
for (int j = i + 1; j < dist.length; j++) {
double val = (int) (Math.random() * maxValue);
Expand Down

0 comments on commit 5503735

Please sign in to comment.