Skip to content

Commit

Permalink
implement ConfusionMatrix metric
Browse files Browse the repository at this point in the history
  • Loading branch information
akondas committed Jul 6, 2016
1 parent cce6899 commit 6c7416a
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
71 changes: 71 additions & 0 deletions src/Phpml/Metric/ConfusionMatrix.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
<?php

declare (strict_types = 1);

namespace Phpml\Metric;

class ConfusionMatrix
{
/**
* @param array $actualLabels
* @param array $predictedLabels
* @param array $labels
*
* @return array
*/
public static function compute(array $actualLabels, array $predictedLabels, array $labels = null): array
{
$labels = $labels ? array_flip($labels) : self::getUniqueLabels($actualLabels);
$matrix = self::generateMatrixWithZeros($labels);

foreach ($actualLabels as $index => $actual) {
$predicted = $predictedLabels[$index];

if (!isset($labels[$actual]) || !isset($labels[$predicted])) {
continue;
}

if ($predicted === $actual) {
$row = $column = $labels[$actual];
} else {
$row = $labels[$actual];
$column = $labels[$predicted];
}

$matrix[$row][$column] += 1;
}

return $matrix;
}

/**
* @param array $labels
*
* @return array
*/
private static function generateMatrixWithZeros(array $labels): array
{
$count = count($labels);
$matrix = [];

for ($i = 0; $i < $count; ++$i) {
$matrix[$i] = array_fill(0, $count, 0);
}

return $matrix;
}

/**
* @param array $labels
*
* @return array
*/
private static function getUniqueLabels(array $labels): array
{
$labels = array_values(array_unique($labels));
sort($labels);
$labels = array_flip($labels);

return $labels;
}
}
61 changes: 61 additions & 0 deletions tests/Phpml/Metric/ConfusionMatrixTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
<?php

declare (strict_types = 1);

namespace tests\Phpml\Metric;

use Phpml\Metric\ConfusionMatrix;

class ConfusionMatrixTest extends \PHPUnit_Framework_TestCase
{
public function testComputeConfusionMatrixOnNumericLabels()
{
$actualLabels = [2, 0, 2, 2, 0, 1];
$predictedLabels = [0, 0, 2, 2, 0, 2];

$confusionMatrix = [
[2, 0, 0],
[0, 0, 1],
[1, 0, 2],
];

$this->assertEquals($confusionMatrix, ConfusionMatrix::compute($actualLabels, $predictedLabels));
}

public function testComputeConfusionMatrixOnStringLabels()
{
$actualLabels = ['cat', 'ant', 'cat', 'cat', 'ant', 'bird'];
$predictedLabels = ['ant', 'ant', 'cat', 'cat', 'ant', 'cat'];

$confusionMatrix = [
[2, 0, 0],
[0, 0, 1],
[1, 0, 2],
];

$this->assertEquals($confusionMatrix, ConfusionMatrix::compute($actualLabels, $predictedLabels));
}

public function testComputeConfusionMatrixOnLabelsWithSubset()
{
$actualLabels = ['cat', 'ant', 'cat', 'cat', 'ant', 'bird'];
$predictedLabels = ['ant', 'ant', 'cat', 'cat', 'ant', 'cat'];
$labels = ['ant', 'bird'];

$confusionMatrix = [
[2, 0],
[0, 0],
];

$this->assertEquals($confusionMatrix, ConfusionMatrix::compute($actualLabels, $predictedLabels, $labels));

$labels = ['bird', 'ant'];

$confusionMatrix = [
[0, 0],
[0, 2],
];

$this->assertEquals($confusionMatrix, ConfusionMatrix::compute($actualLabels, $predictedLabels, $labels));
}
}

0 comments on commit 6c7416a

Please sign in to comment.