Skip to content

Commit

Permalink
Choose averaging method in classification report (#205)
Browse files Browse the repository at this point in the history
* Fix testcases of ClassificationReport

* Fix averaging method in ClassificationReport

* Fix divided by zero if labels are empty

* Fix calculation of f1score

* Add averaging methods (not completed)

* Implement weighted average method

* Extract counts to properties

* Fix default to macro average

* Implement micro average method

* Fix style

* Update docs

* Fix styles
  • Loading branch information
y-uti authored and akondas committed Jan 29, 2018
1 parent ba7114a commit 554c86a
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 31 deletions.
7 changes: 7 additions & 0 deletions docs/machine-learning/metric/classification-report.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ $predictedLabels = ['cat', 'cat', 'bird', 'bird', 'ant'];
$report = new ClassificationReport($actualLabels, $predictedLabels);
```

Optionally you can provide the following parameter:

* `$average` - (int) averaging method for multi-class classification
* `ClassificationReport::MICRO_AVERAGE` = 1
* `ClassificationReport::MACRO_AVERAGE` = 2 (default)
* `ClassificationReport::WEIGHTED_AVERAGE` = 3

### Metrics

After creating the report you can draw its individual metrics:
Expand Down
137 changes: 112 additions & 25 deletions src/Phpml/Metric/ClassificationReport.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,30 @@

namespace Phpml\Metric;

use Phpml\Exception\InvalidArgumentException;

class ClassificationReport
{
public const MICRO_AVERAGE = 1;

public const MACRO_AVERAGE = 2;

public const WEIGHTED_AVERAGE = 3;

/**
* @var array
*/
private $precision = [];
private $truePositive = [];

/**
* @var array
*/
private $recall = [];
private $falsePositive = [];

/**
* @var array
*/
private $f1score = [];
private $falseNegative = [];

/**
* @var array
Expand All @@ -29,26 +37,33 @@ class ClassificationReport
/**
* @var array
*/
private $average = [];
private $precision = [];

public function __construct(array $actualLabels, array $predictedLabels)
{
$truePositive = $falsePositive = $falseNegative = $this->support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
/**
* @var array
*/
private $recall = [];

foreach ($actualLabels as $index => $actual) {
$predicted = $predictedLabels[$index];
++$this->support[$actual];
/**
* @var array
*/
private $f1score = [];

if ($actual === $predicted) {
++$truePositive[$actual];
} else {
++$falsePositive[$predicted];
++$falseNegative[$actual];
}
/**
* @var array
*/
private $average = [];

/**
 * Builds the per-label metrics and the aggregate average for a prediction run.
 *
 * @param array $actualLabels    ground-truth labels
 * @param array $predictedLabels predicted labels, aligned by index with $actualLabels
 * @param int   $average         one of MICRO_AVERAGE, MACRO_AVERAGE (default) or WEIGHTED_AVERAGE
 *
 * @throws InvalidArgumentException when $average is not one of the three averaging constants
 */
public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
{
    $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
    // Strict comparison: the constants are ints, so avoid loose-typed in_array() matches.
    if (!in_array($average, $averagingMethods, true)) {
        throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
    }

    $this->aggregateClassificationResults($actualLabels, $predictedLabels);
    $this->computeMetrics();
    $this->computeAverage($average);
}

public function getPrecision(): array
Expand Down Expand Up @@ -76,20 +91,73 @@ public function getAverage(): array
return $this->average;
}

/**
 * Tallies per-label true positives, false positives, false negatives and
 * support (number of actual occurrences of each label) into properties.
 *
 * @param array $actualLabels    ground-truth labels
 * @param array $predictedLabels predicted labels, aligned by index with $actualLabels
 */
private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
{
    // Initialise every counter map with a zero entry for each label seen in either array.
    $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

    foreach ($actualLabels as $index => $actual) {
        $predicted = $predictedLabels[$index];
        ++$support[$actual];

        if ($actual === $predicted) {
            ++$truePositive[$actual];
        } else {
            // A miss counts against the predicted label (FP) and the actual label (FN).
            ++$falsePositive[$predicted];
            ++$falseNegative[$actual];
        }
    }

    $this->truePositive = $truePositive;
    $this->falsePositive = $falsePositive;
    $this->falseNegative = $falseNegative;
    $this->support = $support;
}

/**
 * Derives per-label precision, recall and F1 score from the aggregated counts.
 */
private function computeMetrics(): void
{
    foreach (array_keys($this->truePositive) as $label) {
        $tp = $this->truePositive[$label];
        $precision = $this->computePrecision($tp, $this->falsePositive[$label]);
        $recall = $this->computeRecall($tp, $this->falseNegative[$label]);

        $this->precision[$label] = $precision;
        $this->recall[$label] = $recall;
        $this->f1score[$label] = $this->computeF1Score((float) $precision, (float) $recall);
    }
}

private function computeAverage(): void
/**
 * Dispatches to the selected averaging strategy.
 *
 * @param int $average one of the MICRO/MACRO/WEIGHTED_AVERAGE constants (validated in the constructor)
 */
private function computeAverage(int $average): void
{
    if ($average === self::MICRO_AVERAGE) {
        $this->computeMicroAverage();
    } elseif ($average === self::MACRO_AVERAGE) {
        $this->computeMacroAverage();
    } elseif ($average === self::WEIGHTED_AVERAGE) {
        $this->computeWeightedAverage();
    }
}

/**
 * Micro average: pool the per-label counts globally, then compute each metric once.
 */
private function computeMicroAverage(): void
{
    $tp = array_sum($this->truePositive);
    $fp = array_sum($this->falsePositive);
    $fn = array_sum($this->falseNegative);

    $precision = $this->computePrecision($tp, $fp);
    $recall = $this->computeRecall($tp, $fn);

    $this->average = [
        'precision' => $precision,
        'recall' => $recall,
        'f1score' => $this->computeF1Score((float) $precision, (float) $recall),
    ];
}

private function computeMacroAverage(): void
{
foreach (['precision', 'recall', 'f1score'] as $metric) {
$values = array_filter($this->{$metric});
if (empty($values)) {
$values = $this->{$metric};
if (count($values) == 0) {
$this->average[$metric] = 0.0;

continue;
Expand All @@ -99,6 +167,25 @@ private function computeAverage(): void
}
}

/**
 * Weighted average: per-label metrics weighted by support (the number of actual
 * occurrences of each label), normalised by the total support.
 */
private function computeWeightedAverage(): void
{
    $totalSupport = array_sum($this->support);

    foreach (['precision', 'recall', 'f1score'] as $metric) {
        $values = $this->{$metric};
        // Guard both an empty metric map and zero total support (e.g. labels that only
        // appear in the predictions) so we never divide by zero below.
        if (count($values) === 0 || $totalSupport === 0) {
            $this->average[$metric] = 0.0;

            continue;
        }

        $sum = 0;
        foreach ($values as $label => $value) {
            $sum += $value * $this->support[$label];
        }

        $this->average[$metric] = $sum / $totalSupport;
    }
}

/**
* @return float|string
*/
Expand Down
86 changes: 80 additions & 6 deletions tests/Phpml/Metric/ClassificationReportTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace Phpml\Tests\Metric;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Metric\ClassificationReport;
use PHPUnit\Framework\TestCase;

Expand Down Expand Up @@ -36,10 +37,12 @@ public function testClassificationReportGenerateWithStringLabels(): void
'ant' => 1,
'bird' => 3,
];

// ClassificationReport uses macro-averaging as default
$average = [
'precision' => 0.75,
'recall' => 0.83,
'f1score' => 0.73,
'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
];

$this->assertEquals($precision, $report->getPrecision(), '', 0.01);
Expand Down Expand Up @@ -77,9 +80,9 @@ public function testClassificationReportGenerateWithNumericLabels(): void
2 => 3,
];
$average = [
'precision' => 0.75,
'recall' => 0.83,
'f1score' => 0.73,
'precision' => 0.5,
'recall' => 0.56,
'f1score' => 0.49,
];

$this->assertEquals($precision, $report->getPrecision(), '', 0.01);
Expand All @@ -89,6 +92,63 @@ public function testClassificationReportGenerateWithNumericLabels(): void
$this->assertEquals($average, $report->getAverage(), '', 0.01);
}

public function testClassificationReportAverageOutOfRange(): void
{
    $labels = ['cat', 'ant', 'bird', 'bird', 'bird'];
    $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];

    $this->expectException(InvalidArgumentException::class);
    // 0 is below MICRO_AVERAGE (1), so the constructor must reject it.
    // No assignment: the constructor throws, so the result would be unused anyway.
    new ClassificationReport($labels, $predicted, 0);
}

public function testClassificationReportMicroAverage(): void
{
    $actual = ['cat', 'ant', 'bird', 'bird', 'bird'];
    $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];

    $report = new ClassificationReport($actual, $predicted, ClassificationReport::MICRO_AVERAGE);

    // Micro averaging pools counts across all labels before computing each metric.
    $expected = [
        'precision' => 0.6, // TP / (TP + FP) = (1 + 0 + 2) / (2 + 1 + 2) = 3/5
        'recall' => 0.6, // TP / (TP + FN) = (1 + 0 + 2) / (1 + 1 + 3) = 3/5
        'f1score' => 0.6, // Harmonic mean of precision and recall
    ];

    $this->assertEquals($expected, $report->getAverage(), '', 0.01);
}

public function testClassificationReportMacroAverage(): void
{
    $actual = ['cat', 'ant', 'bird', 'bird', 'bird'];
    $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];

    $report = new ClassificationReport($actual, $predicted, ClassificationReport::MACRO_AVERAGE);

    // Macro averaging takes the unweighted mean of the per-label metrics.
    $expected = [
        'precision' => 0.5, // (1/2 + 0 + 1) / 3 = 1/2
        'recall' => 0.56, // (1 + 0 + 2/3) / 3 = 5/9
        'f1score' => 0.49, // (2/3 + 0 + 4/5) / 3 = 22/45
    ];

    $this->assertEquals($expected, $report->getAverage(), '', 0.01);
}

public function testClassificationReportWeightedAverage(): void
{
    $actual = ['cat', 'ant', 'bird', 'bird', 'bird'];
    $predicted = ['cat', 'cat', 'bird', 'bird', 'ant'];

    $report = new ClassificationReport($actual, $predicted, ClassificationReport::WEIGHTED_AVERAGE);

    // Weighted averaging weights each label's metric by its support (actual occurrences).
    $expected = [
        'precision' => 0.7, // (1/2 * 1 + 0 * 1 + 1 * 3) / 5 = 7/10
        'recall' => 0.6, // (1 * 1 + 0 * 1 + 2/3 * 3) / 5 = 3/5
        'f1score' => 0.61, // (2/3 * 1 + 0 * 1 + 4/5 * 3) / 5 = 46/75
    ];

    $this->assertEquals($expected, $report->getAverage(), '', 0.01);
}

public function testPreventDivideByZeroWhenTruePositiveAndFalsePositiveSumEqualsZero(): void
{
$labels = [1, 2];
Expand Down Expand Up @@ -129,4 +189,18 @@ public function testPreventDividedByZeroWhenPredictedLabelsAllNotMatch(): void
'f1score' => 0,
], $report->getAverage(), '', 0.01);
}

public function testPreventDividedByZeroWhenLabelsAreEmpty(): void
{
    // With no samples at all, every averaged metric should fall back to zero
    // instead of triggering a division-by-zero warning.
    $report = new ClassificationReport([], []);

    $expected = [
        'precision' => 0,
        'recall' => 0,
        'f1score' => 0,
    ];

    $this->assertEquals($expected, $report->getAverage(), '', 0.01);
}
}

0 comments on commit 554c86a

Please sign in to comment.