Skip to content

Commit

Permalink
Add RandomForest exception tests (#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
marmichalski authored and akondas committed Mar 4, 2018
1 parent 8976047 commit 941d240
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/Classification/DecisionTree/DecisionTreeLeaf.php
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ class DecisionTreeLeaf
public $columnIndex;

/**
* @var ?DecisionTreeLeaf
* @var DecisionTreeLeaf|null
*/
public $leftLeaf;

/**
* @var ?DecisionTreeLeaf
* @var DecisionTreeLeaf|null
*/
public $rightLeaf;

Expand Down
24 changes: 11 additions & 13 deletions src/Classification/Ensemble/RandomForest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

namespace Phpml\Classification\Ensemble;

use Exception;
use Phpml\Classification\Classifier;
use Phpml\Classification\DecisionTree;
use Phpml\Exception\InvalidArgumentException;

class RandomForest extends Bagging
{
Expand Down Expand Up @@ -41,20 +41,20 @@ public function __construct(int $numClassifier = 50)
* Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
* features to be taken into consideration while selecting subspace of features
*
* @param mixed $ratio string or float should be given
*
* @return $this
*
* @throws \Exception
* @param string|float $ratio
*/
public function setFeatureSubsetRatio($ratio)
public function setFeatureSubsetRatio($ratio): self
{
if (!is_string($ratio) && !is_float($ratio)) {
throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
}

if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
throw new Exception('When a float given, feature subset ratio should be between 0.1 and 1.0');
throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
}

if (is_string($ratio) && $ratio != 'sqrt' && $ratio != 'log') {
throw new Exception("When a string given, feature subset ratio can only be 'sqrt' or 'log' ");
throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
}

$this->featureSubsetRatio = $ratio;
Expand All @@ -66,13 +66,11 @@ public function setFeatureSubsetRatio($ratio)
* RandomForest algorithm is usable *only* with DecisionTree
*
* @return $this
*
* @throws \Exception
*/
public function setClassifer(string $classifier, array $classifierOptions = [])
{
if ($classifier != DecisionTree::class) {
throw new Exception('RandomForest can only use DecisionTree as base classifier');
throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
}

return parent::setClassifer($classifier, $classifierOptions);
Expand Down Expand Up @@ -133,7 +131,7 @@ protected function initSingleClassifier(Classifier $classifier): Classifier
{
if (is_float($this->featureSubsetRatio)) {
$featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
} elseif ($this->featureCount == 'sqrt') {
} elseif ($this->featureSubsetRatio == 'sqrt') {
$featureCount = (int) sqrt($this->featureCount) + 1;
} else {
$featureCount = (int) log($this->featureCount, 2) + 1;
Expand Down
43 changes: 34 additions & 9 deletions tests/Classification/Ensemble/RandomForestTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,44 @@
use Phpml\Classification\DecisionTree;
use Phpml\Classification\Ensemble\RandomForest;
use Phpml\Classification\NaiveBayes;
use Throwable;
use Phpml\Exception\InvalidArgumentException;

class RandomForestTest extends BaggingTest
{
public function testOtherBaseClassifier(): void
public function testThrowExceptionWithInvalidClassifier(): void
{
try {
$classifier = new RandomForest();
$classifier->setClassifer(NaiveBayes::class);
$this->assertEquals(0, 1);
} catch (Throwable $ex) {
$this->assertEquals(1, 1);
}
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('RandomForest can only use DecisionTree as base classifier');

$classifier = new RandomForest();
$classifier->setClassifer(NaiveBayes::class);
}

public function testThrowExceptionWithInvalidFeatureSubsetRatioType(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Feature subset ratio must be a string or a float');

$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio(1);
}

public function testThrowExceptionWithInvalidFeatureSubsetRatioFloat(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('When a float is given, feature subset ratio should be between 0.1 and 1.0');

$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio(1.1);
}

public function testThrowExceptionWithInvalidFeatureSubsetRatioString(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");

$classifier = new RandomForest();
$classifier->setFeatureSubsetRatio('pow');
}

protected function getClassifier($numBaseClassifiers = 50)
Expand Down

0 comments on commit 941d240

Please sign in to comment.