Skip to content

Commit

Permalink
create StratifiedRandomSplit for cross validation
Browse files Browse the repository at this point in the history
  • Loading branch information
akondas committed Jul 10, 2016
1 parent 0213208 commit f04cc04
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 79 deletions.
83 changes: 4 additions & 79 deletions src/Phpml/CrossValidation/RandomSplit.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,101 +5,26 @@
namespace Phpml\CrossValidation;

use Phpml\Dataset\Dataset;
use Phpml\Exception\InvalidArgumentException;

class RandomSplit
class RandomSplit extends Split
{
/**
* @var array
*/
private $trainSamples = [];

/**
* @var array
*/
private $testSamples = [];

/**
* @var array
*/
private $trainLabels = [];

/**
* @var array
*/
private $testLabels = [];

/**
* @param Dataset $dataset
* @param float $testSize
* @param int $seed
*
* @throws InvalidArgumentException
*/
public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
protected function splitDataset(Dataset $dataset, float $testSize)
{
if (0 >= $testSize || 1 <= $testSize) {
throw InvalidArgumentException::percentNotInRange('testSize');
}
$this->seedGenerator($seed);

$samples = $dataset->getSamples();
$labels = $dataset->getTargets();
$datasetSize = count($samples);
$testCount = count($this->testSamples);

for ($i = $datasetSize; $i > 0; --$i) {
$key = mt_rand(0, $datasetSize - 1);
$setName = count($this->testSamples) / $datasetSize >= $testSize ? 'train' : 'test';
$setName = (count($this->testSamples) - $testCount) / $datasetSize >= $testSize ? 'train' : 'test';

$this->{$setName.'Samples'}[] = $samples[$key];
$this->{$setName.'Labels'}[] = $labels[$key];

$samples = array_values($samples);
$labels = array_values($labels);
}
}

/**
* @return array
*/
public function getTrainSamples()
{
return $this->trainSamples;
}

/**
* @return array
*/
public function getTestSamples()
{
return $this->testSamples;
}

/**
* @return array
*/
public function getTrainLabels()
{
return $this->trainLabels;
}

/**
* @return array
*/
public function getTestLabels()
{
return $this->testLabels;
}

/**
* @param int|null $seed
*/
private function seedGenerator(int $seed = null)
{
if (null === $seed) {
mt_srand();
} else {
mt_srand($seed);
}
}
}
94 changes: 94 additions & 0 deletions src/Phpml/CrossValidation/Split.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
<?php

declare (strict_types = 1);

namespace Phpml\CrossValidation;

use Phpml\Dataset\Dataset;
use Phpml\Exception\InvalidArgumentException;

abstract class Split
{
    /**
     * Samples of the training part; filled by the concrete splitDataset().
     *
     * @var array
     */
    protected $trainSamples = [];

    /**
     * Samples of the test part; filled by the concrete splitDataset().
     *
     * @var array
     */
    protected $testSamples = [];

    /**
     * Labels of the training part; filled by the concrete splitDataset().
     *
     * @var array
     */
    protected $trainLabels = [];

    /**
     * Labels of the test part; filled by the concrete splitDataset().
     *
     * @var array
     */
    protected $testLabels = [];

    /**
     * Validates the requested test fraction, seeds the Mersenne Twister
     * generator and delegates the actual partitioning to splitDataset().
     *
     * @param Dataset  $dataset  data to partition
     * @param float    $testSize fraction for the test part, strictly between 0 and 1
     * @param int|null $seed     generator seed; null lets mt_srand() choose one
     *
     * @throws InvalidArgumentException when $testSize is not inside (0, 1)
     */
    public function __construct(Dataset $dataset, float $testSize = 0.3, int $seed = null)
    {
        if ($testSize <= 0 || $testSize >= 1) {
            throw InvalidArgumentException::percentNotInRange('testSize');
        }

        $this->seedGenerator($seed);
        $this->splitDataset($dataset, $testSize);
    }

    /**
     * Partitioning strategy supplied by the subclass. It is invoked from the
     * constructor after the test size has been validated and the generator
     * has been seeded.
     *
     * @param Dataset $dataset
     * @param float   $testSize
     */
    abstract protected function splitDataset(Dataset $dataset, float $testSize);

    /**
     * @return array samples of the training part
     */
    public function getTrainSamples()
    {
        return $this->trainSamples;
    }

    /**
     * @return array samples of the test part
     */
    public function getTestSamples()
    {
        return $this->testSamples;
    }

    /**
     * @return array labels of the training part
     */
    public function getTrainLabels()
    {
        return $this->trainLabels;
    }

    /**
     * @return array labels of the test part
     */
    public function getTestLabels()
    {
        return $this->testLabels;
    }

    /**
     * Seeds mt_rand(): with the given value when one is supplied, otherwise
     * by calling mt_srand() without an argument.
     *
     * @param int|null $seed
     */
    protected function seedGenerator(int $seed = null)
    {
        if (null !== $seed) {
            mt_srand($seed);

            return;
        }

        mt_srand();
    }
}
62 changes: 62 additions & 0 deletions src/Phpml/CrossValidation/StratifiedRandomSplit.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
<?php

declare (strict_types = 1);

namespace Phpml\CrossValidation;

use Phpml\Dataset\ArrayDataset;
use Phpml\Dataset\Dataset;

class StratifiedRandomSplit extends RandomSplit
{
    /**
     * Partitions the dataset one target value at a time: the dataset is first
     * divided into per-target subsets and the parent split is then applied to
     * each subset with the same $testSize, appending to the shared train/test
     * collections.
     *
     * @param Dataset $dataset
     * @param float   $testSize
     */
    protected function splitDataset(Dataset $dataset, float $testSize)
    {
        foreach ($this->splitByTarget($dataset) as $targetSet) {
            parent::splitDataset($targetSet, $testSize);
        }
    }

    /**
     * Groups the samples of a dataset by their target value.
     *
     * Targets are grouped by their string representation rather than being
     * used as array keys directly. PHP truncates float keys to integers, so
     * using the raw target would silently merge classes such as 1.5 and 1.7.
     * String comparison is also the equivalence array_unique() applies by
     * default, so int and string targets are grouped exactly as before.
     *
     * @param Dataset $dataset
     *
     * @return Dataset[]|array one dataset per distinct target, in order of first appearance
     */
    private function splitByTarget(Dataset $dataset): array
    {
        $targets = $dataset->getTargets();
        $samples = $dataset->getSamples();

        $uniqueTargets = [];
        $split = [];

        foreach ($samples as $key => $sample) {
            $target = $targets[$key];
            $group = (string) $target;

            if (!isset($split[$group])) {
                // Remember the first original value so labels keep their type.
                $uniqueTargets[$group] = $target;
                $split[$group] = [];
            }

            $split[$group][] = $sample;
        }

        return $this->createDatasets($uniqueTargets, $split);
    }

    /**
     * Builds one ArrayDataset per group, labelling every sample of a group
     * with that group's original target value.
     *
     * @param array $uniqueTargets original target value of each group, indexed by group key
     * @param array $split         samples of each group, indexed by the same group key
     *
     * @return array datasets indexed by group key
     */
    private function createDatasets(array $uniqueTargets, array $split): array
    {
        $datasets = [];
        foreach ($uniqueTargets as $group => $target) {
            $groupSamples = $split[$group];
            $datasets[$group] = new ArrayDataset($groupSamples, array_fill(0, count($groupSamples), $target));
        }

        return $datasets;
    }
}
65 changes: 65 additions & 0 deletions tests/Phpml/CrossValidation/StratifiedRandomSplitTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
<?php

declare (strict_types = 1);

namespace tests\Phpml\CrossValidation;

use Phpml\CrossValidation\StratifiedRandomSplit;
use Phpml\Dataset\ArrayDataset;

class StratifiedRandomSplitTest extends \PHPUnit_Framework_TestCase
{
    /**
     * Two string classes with four samples each: the test part must hold the
     * same number of samples from both classes for each requested test size.
     */
    public function testDatasetStratifiedRandomSplitWithEvenDistribution()
    {
        $features = [[1], [2], [3], [4], [5], [6], [7], [8]];
        $classes = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b'];
        $dataset = new ArrayDataset($features, $classes);

        $halfSplit = new StratifiedRandomSplit($dataset, 0.5);

        $this->assertEquals(2, $this->countSamplesByTarget($halfSplit->getTestLabels(), 'a'));
        $this->assertEquals(2, $this->countSamplesByTarget($halfSplit->getTestLabels(), 'b'));

        $quarterSplit = new StratifiedRandomSplit($dataset, 0.25);

        $this->assertEquals(1, $this->countSamplesByTarget($quarterSplit->getTestLabels(), 'a'));
        $this->assertEquals(1, $this->countSamplesByTarget($quarterSplit->getTestLabels(), 'b'));
    }

    /**
     * Same scenario with interleaved integer targets, checking that numeric
     * labels come back from the split as integers.
     */
    public function testDatasetStratifiedRandomSplitWithEvenDistributionAndNumericTargets()
    {
        $features = [[1], [2], [3], [4], [5], [6], [7], [8]];
        $classes = [1, 2, 1, 2, 1, 2, 1, 2];
        $dataset = new ArrayDataset($features, $classes);

        $halfSplit = new StratifiedRandomSplit($dataset, 0.5);

        $this->assertEquals(2, $this->countSamplesByTarget($halfSplit->getTestLabels(), 1));
        $this->assertEquals(2, $this->countSamplesByTarget($halfSplit->getTestLabels(), 2));

        $quarterSplit = new StratifiedRandomSplit($dataset, 0.25);

        $this->assertEquals(1, $this->countSamplesByTarget($quarterSplit->getTestLabels(), 1));
        $this->assertEquals(1, $this->countSamplesByTarget($quarterSplit->getTestLabels(), 2));
    }

    /**
     * Counts the labels that are identical (===) to the requested target.
     *
     * @param $splitTargets
     * @param $countTarget
     *
     * @return int
     */
    private function countSamplesByTarget($splitTargets, $countTarget): int
    {
        // Third argument enables strict comparison, so 1 and '1' are distinct.
        return count(array_keys($splitTargets, $countTarget, true));
    }
}

0 comments on commit f04cc04

Please sign in to comment.