forked from jorgecasas/php-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
183 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\Metric; | ||
|
||
use Phpml\Exception\InvalidArgumentException; | ||
use Phpml\Math\Statistic\Correlation; | ||
use Phpml\Math\Statistic\Mean; | ||
|
||
final class Regression | ||
{ | ||
public static function meanSquaredError(array $targets, array $predictions): float | ||
{ | ||
self::assertCountEquals($targets, $predictions); | ||
|
||
$errors = []; | ||
foreach ($targets as $index => $target) { | ||
$errors[] = (($target - $predictions[$index]) ** 2); | ||
} | ||
|
||
return Mean::arithmetic($errors); | ||
} | ||
|
||
public static function meanSquaredLogarithmicError(array $targets, array $predictions): float | ||
{ | ||
self::assertCountEquals($targets, $predictions); | ||
|
||
$errors = []; | ||
foreach ($targets as $index => $target) { | ||
$errors[] = (log(1 + $target) - log(1 + $predictions[$index])) ** 2; | ||
} | ||
|
||
return Mean::arithmetic($errors); | ||
} | ||
|
||
public static function meanAbsoluteError(array $targets, array $predictions): float | ||
{ | ||
self::assertCountEquals($targets, $predictions); | ||
|
||
$errors = []; | ||
foreach ($targets as $index => $target) { | ||
$errors[] = abs($target - $predictions[$index]); | ||
} | ||
|
||
return Mean::arithmetic($errors); | ||
} | ||
|
||
public static function medianAbsoluteError(array $targets, array $predictions): float | ||
{ | ||
self::assertCountEquals($targets, $predictions); | ||
|
||
$errors = []; | ||
foreach ($targets as $index => $target) { | ||
$errors[] = abs($target - $predictions[$index]); | ||
} | ||
|
||
return (float) Mean::median($errors); | ||
} | ||
|
||
public static function r2Score(array $targets, array $predictions): float | ||
{ | ||
self::assertCountEquals($targets, $predictions); | ||
|
||
return Correlation::pearson($targets, $predictions) ** 2; | ||
} | ||
|
||
public static function maxError(array $targets, array $predictions): float | ||
{ | ||
self::assertCountEquals($targets, $predictions); | ||
|
||
$errors = []; | ||
foreach ($targets as $index => $target) { | ||
$errors[] = abs($target - $predictions[$index]); | ||
} | ||
|
||
return (float) max($errors); | ||
} | ||
|
||
private static function assertCountEquals(array &$targets, array &$predictions): void | ||
{ | ||
if (count($targets) !== count($predictions)) { | ||
throw new InvalidArgumentException('Targets count must be equal with predictions count'); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\Tests\Metric; | ||
|
||
use Phpml\Exception\InvalidArgumentException; | ||
use Phpml\Metric\Regression; | ||
use PHPUnit\Framework\TestCase; | ||
|
||
final class RegressionTest extends TestCase | ||
{ | ||
public function testMeanSquaredError(): void | ||
{ | ||
self::assertEquals(6.08, Regression::meanSquaredError( | ||
[41, 45, 49, 47, 44], | ||
[43.6, 44.4, 45.2, 46, 46.8] | ||
)); | ||
|
||
self::assertEquals(0.375, Regression::meanSquaredError( | ||
[3, -0.5, 2, 7], | ||
[2.5, 0.0, 2, 8] | ||
)); | ||
} | ||
|
||
public function testR2Score(): void | ||
{ | ||
self::assertEqualsWithDelta(0.1739, Regression::r2Score( | ||
[41, 45, 49, 47, 44], | ||
[43.6, 44.4, 45.2, 46, 46.8] | ||
), 0.0001); | ||
} | ||
|
||
public function testMaxError(): void | ||
{ | ||
self::assertEquals(1, Regression::maxError([3, 2, 7, 1], [4, 2, 7, 1])); | ||
|
||
// test absolute value | ||
self::assertEquals(5, Regression::maxError([-10, 2, 7, 1], [-5, 2, 7, 1])); | ||
} | ||
|
||
public function testMeanAbsoluteError(): void | ||
{ | ||
self::assertEquals(0.5, Regression::meanAbsoluteError([3, -0.5, 2, 7], [2.5, 0.0, 2, 8])); | ||
} | ||
|
||
public function testMeanSquaredLogarithmicError(): void | ||
{ | ||
self::assertEqualsWithDelta(0.039, Regression::meanSquaredLogarithmicError( | ||
[3, 5, 2.5, 7], | ||
[2.5, 5, 4, 8] | ||
), 0.001); | ||
} | ||
|
||
public function testMedianAbsoluteError(): void | ||
{ | ||
self::assertEquals(0.5, Regression::medianAbsoluteError( | ||
[3, -0.5, 2, 7], | ||
[2.5, 0.0, 2, 8] | ||
)); | ||
} | ||
|
||
public function testMeanSquaredErrorInvalidCount(): void | ||
{ | ||
self::expectException(InvalidArgumentException::class); | ||
|
||
Regression::meanSquaredError([1], [1, 2]); | ||
} | ||
|
||
public function testR2ScoreInvalidCount(): void | ||
{ | ||
self::expectException(InvalidArgumentException::class); | ||
|
||
Regression::r2Score([1], [1, 2]); | ||
} | ||
|
||
public function testMaxErrorInvalidCount(): void | ||
{ | ||
self::expectException(InvalidArgumentException::class); | ||
|
||
Regression::r2Score([1], [1, 2]); | ||
} | ||
|
||
public function tesMeanAbsoluteErrorInvalidCount(): void | ||
{ | ||
self::expectException(InvalidArgumentException::class); | ||
|
||
Regression::meanAbsoluteError([1], [1, 2]); | ||
} | ||
|
||
public function tesMediaAbsoluteErrorInvalidCount(): void | ||
{ | ||
self::expectException(InvalidArgumentException::class); | ||
|
||
Regression::medianAbsoluteError([1], [1, 2]); | ||
} | ||
} |