Skip to content

Commit

Permalink
Check if feature exist when predict target in NaiveBayes (#327)
Browse files Browse the repository at this point in the history
* Check if feature exist when predict target in NaiveBayes

* Fix typo
  • Loading branch information
akondas authored Nov 7, 2018
1 parent 18c36b9 commit d30c212
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/Classification/NaiveBayes.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace Phpml\Classification;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;
Expand Down Expand Up @@ -137,6 +138,10 @@ private function calculateStatistics(string $label, array $samples): void
*/
private function sampleProbability(array $sample, int $feature, string $label): float
{
if (!isset($sample[$feature])) {
throw new InvalidArgumentException('Missing feature. All samples must have equal number of features');
}

$value = $sample[$feature];
if ($this->dataType[$label][$feature] == self::NOMINAL) {
if (!isset($this->discreteProb[$label][$feature][$value]) ||
Expand Down
16 changes: 16 additions & 0 deletions tests/Classification/NaiveBayesTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Phpml\Tests\Classification;

use Phpml\Classification\NaiveBayes;
use Phpml\Exception\InvalidArgumentException;
use Phpml\ModelManager;
use PHPUnit\Framework\TestCase;

Expand Down Expand Up @@ -125,4 +126,19 @@ public function testSaveAndRestoreNumericLabels(): void
self::assertEquals($classifier, $restoredClassifier);
self::assertEquals($predicted, $restoredClassifier->predict($testSamples));
}

public function testInconsistentFeaturesInSamples(): void
{
$trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]];
$trainLabels = ['1996', '1997', '1998'];

$testSamples = [[3, 1, 1], [5, 1], [4, 3, 8]];

$classifier = new NaiveBayes();
$classifier->train($trainSamples, $trainLabels);

$this->expectException(InvalidArgumentException::class);

$classifier->predict($testSamples);
}
}

0 comments on commit d30c212

Please sign in to comment.