forked from jorgecasas/php-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create StratifiedRandomSplit for cross validation
- Loading branch information
Showing
4 changed files
with
225 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
<?php | ||
|
||
declare (strict_types = 1); | ||
|
||
namespace Phpml\CrossValidation; | ||
|
||
use Phpml\Dataset\Dataset; | ||
use Phpml\Exception\InvalidArgumentException; | ||
|
||
abstract class Split
{
    /**
     * Training samples, exposed through getTrainSamples().
     *
     * @var array
     */
    protected $trainSamples = [];

    /**
     * Test samples, exposed through getTestSamples().
     *
     * @var array
     */
    protected $testSamples = [];

    /**
     * Training labels, exposed through getTrainLabels().
     *
     * @var array
     */
    protected $trainLabels = [];

    /**
     * Test labels, exposed through getTestLabels().
     *
     * @var array
     */
    protected $testLabels = [];

    /**
     * Validates the requested test share, seeds the random generator and
     * then hands the dataset to the subclass-specific splitDataset().
     *
     * @param Dataset  $dataset  dataset to split
     * @param float    $testSize test share; must lie strictly between 0 and 1
     * @param int|null $seed     value passed to mt_srand(); when null, mt_srand() is called without an argument
     *
     * @throws InvalidArgumentException when $testSize is not strictly between 0 and 1
     */
    public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
    {
        if ($testSize <= 0 || $testSize >= 1) {
            throw InvalidArgumentException::percentNotInRange('testSize');
        }

        // The generator is seeded before the split is performed.
        $this->seedGenerator($seed);
        $this->splitDataset($dataset, $testSize);
    }

    /**
     * Performs the actual split; implemented by subclasses and invoked once
     * from the constructor after validation and seeding.
     *
     * @param Dataset $dataset
     * @param float   $testSize
     */
    abstract protected function splitDataset(Dataset $dataset, float $testSize);

    /**
     * @return array current content of the training sample list
     */
    public function getTrainSamples()
    {
        return $this->trainSamples;
    }

    /**
     * @return array current content of the test sample list
     */
    public function getTestSamples()
    {
        return $this->testSamples;
    }

    /**
     * @return array current content of the training label list
     */
    public function getTrainLabels()
    {
        return $this->trainLabels;
    }

    /**
     * @return array current content of the test label list
     */
    public function getTestLabels()
    {
        return $this->testLabels;
    }

    /**
     * Seeds PHP's mt_* random generator.
     *
     * @param int|null $seed explicit seed, or null to call mt_srand() without an argument
     */
    protected function seedGenerator(int $seed = null)
    {
        if (null !== $seed) {
            mt_srand($seed);

            return;
        }

        mt_srand();
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
<?php | ||
|
||
declare (strict_types = 1); | ||
|
||
namespace Phpml\CrossValidation; | ||
|
||
use Phpml\Dataset\ArrayDataset; | ||
use Phpml\Dataset\Dataset; | ||
|
||
class StratifiedRandomSplit extends RandomSplit
{
    /**
     * Runs the inherited split once for every group of samples that share
     * the same target value.
     *
     * @param Dataset $dataset
     * @param float   $testSize
     */
    protected function splitDataset(Dataset $dataset, float $testSize)
    {
        foreach ($this->splitByTarget($dataset) as $targetGroup) {
            parent::splitDataset($targetGroup, $testSize);
        }
    }

    /**
     * Partitions the dataset into one dataset per distinct target value.
     *
     * @param Dataset $dataset
     *
     * @return Dataset[]|array datasets keyed by target value
     */
    private function splitByTarget(Dataset $dataset): array
    {
        $labels = $dataset->getTargets();
        $rows = $dataset->getSamples();

        // Start every distinct label with an empty bucket, then distribute the rows.
        $distinctLabels = array_unique($labels);
        $buckets = array_combine($distinctLabels, array_fill(0, count($distinctLabels), []));

        foreach ($rows as $index => $row) {
            $buckets[$labels[$index]][] = $row;
        }

        return $this->createDatasets($distinctLabels, $buckets);
    }

    /**
     * Wraps each bucket of samples in an ArrayDataset whose targets all equal
     * the bucket's label.
     *
     * @param array $uniqueTargets distinct target values
     * @param array $split         samples grouped by target value
     *
     * @return array datasets keyed by target value
     */
    private function createDatasets(array $uniqueTargets, array $split): array
    {
        $result = [];

        foreach ($uniqueTargets as $label) {
            $bucket = $split[$label];
            $result[$label] = new ArrayDataset($bucket, array_fill(0, count($bucket), $label));
        }

        return $result;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
<?php | ||
|
||
declare (strict_types = 1); | ||
|
||
namespace tests\Phpml\CrossValidation; | ||
|
||
use Phpml\CrossValidation\StratifiedRandomSplit; | ||
use Phpml\Dataset\ArrayDataset; | ||
|
||
class StratifiedRandomSplitTest extends \PHPUnit_Framework_TestCase
{
    /**
     * Four samples per string label: a 0.5 test share is expected to yield two
     * test samples per label, a 0.25 share one per label.
     */
    public function testDatasetStratifiedRandomSplitWithEvenDistribution()
    {
        $dataset = new ArrayDataset(
            [[1], [2], [3], [4], [5], [6], [7], [8]],
            ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b']
        );

        $testLabels = (new StratifiedRandomSplit($dataset, 0.5))->getTestLabels();
        $this->assertEquals(2, $this->countSamplesByTarget($testLabels, 'a'));
        $this->assertEquals(2, $this->countSamplesByTarget($testLabels, 'b'));

        $testLabels = (new StratifiedRandomSplit($dataset, 0.25))->getTestLabels();
        $this->assertEquals(1, $this->countSamplesByTarget($testLabels, 'a'));
        $this->assertEquals(1, $this->countSamplesByTarget($testLabels, 'b'));
    }

    /**
     * Same expectations as above, with interleaved integer labels.
     */
    public function testDatasetStratifiedRandomSplitWithEvenDistributionAndNumericTargets()
    {
        $dataset = new ArrayDataset(
            [[1], [2], [3], [4], [5], [6], [7], [8]],
            [1, 2, 1, 2, 1, 2, 1, 2]
        );

        $testLabels = (new StratifiedRandomSplit($dataset, 0.5))->getTestLabels();
        $this->assertEquals(2, $this->countSamplesByTarget($testLabels, 1));
        $this->assertEquals(2, $this->countSamplesByTarget($testLabels, 2));

        $testLabels = (new StratifiedRandomSplit($dataset, 0.25))->getTestLabels();
        $this->assertEquals(1, $this->countSamplesByTarget($testLabels, 1));
        $this->assertEquals(1, $this->countSamplesByTarget($testLabels, 2));
    }

    /**
     * Counts the entries of $splitTargets that are identical (===) to $countTarget.
     *
     * @param $splitTargets
     * @param $countTarget
     *
     * @return int
     */
    private function countSamplesByTarget($splitTargets, $countTarget): int
    {
        // The third argument enables strict comparison.
        return count(array_keys($splitTargets, $countTarget, true));
    }
}