Implement VarianceThreshold - simple baseline approach to feature selection. (#228)

* Add sum of squares deviations

* Calculate population variance

* Add VarianceThreshold - feature selection transformer

* Add docs about VarianceThreshold

* Add missing code for pipeline usage
akondas authored Feb 10, 2018
1 parent 4b5d57f commit 3ba3591
Showing 10 changed files with 279 additions and 10 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -87,6 +87,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples](
* Cross Validation
    * [Random Split](http://php-ml.readthedocs.io/en/latest/machine-learning/cross-validation/random-split/)
    * [Stratified Random Split](http://php-ml.readthedocs.io/en/latest/machine-learning/cross-validation/stratified-random-split/)
* Feature Selection
    * [Variance Threshold](http://php-ml.readthedocs.io/en/latest/machine-learning/feature-selection/variance-threshold/)
* Preprocessing
    * [Normalization](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/normalization/)
    * [Imputation missing values](http://php-ml.readthedocs.io/en/latest/machine-learning/preprocessing/imputation-missing-values/)
2 changes: 2 additions & 0 deletions docs/index.md
@@ -76,6 +76,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples](
* Cross Validation
    * [Random Split](machine-learning/cross-validation/random-split.md)
    * [Stratified Random Split](machine-learning/cross-validation/stratified-random-split.md)
* Feature Selection
    * [Variance Threshold](machine-learning/feature-selection/variance-threshold.md)
* Preprocessing
    * [Normalization](machine-learning/preprocessing/normalization.md)
    * [Imputation missing values](machine-learning/preprocessing/imputation-missing-values.md)
60 changes: 60 additions & 0 deletions docs/machine-learning/feature-selection/variance-threshold.md
@@ -0,0 +1,60 @@
# Variance Threshold

`VarianceThreshold` is a simple baseline approach to feature selection.
It removes all features whose variance does not exceed a given threshold.
By default, it removes all zero-variance features, i.e. features that have the same value in all samples.
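
As a sketch of the quantity being thresholded, the block below writes out the population variance of one feature column and the keep rule. The symbols `N` (number of samples), `x_{ij}` (value of feature `j` in sample `i`), `\mu_j` (column mean) and `t` (the threshold) are notation introduced here for illustration, not names from the library. Dividing by `N` follows `Variance::population` as added in this commit, and the strict inequality mirrors the comparison made in `fit()`.

```latex
\mu_j = \frac{1}{N} \sum_{i=1}^{N} x_{ij},
\qquad
\sigma_j^2 = \frac{1}{N} \sum_{i=1}^{N} \left( x_{ij} - \mu_j \right)^2,
\qquad
\text{keep feature } j \iff \sigma_j^2 > t
```

With the default `t = 0`, a constant column has a variance of exactly zero, so it fails the strict comparison and is dropped.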

## Constructor Parameters

* $threshold (float) - features with a variance less than or equal to this threshold will be removed (default 0.0)

```php
use Phpml\FeatureSelection\VarianceThreshold;

$transformer = new VarianceThreshold(0.15);
```

## Example of use

As an example, suppose that we have a dataset with boolean features and
we want to remove all features that are either one or zero (on or off)
in more than 80% of the samples.
Boolean features are Bernoulli random variables, and the variance of such
variables is given by
```
Var[X] = p(1 - p)
```
so we can select using the threshold .8 * (1 - .8):

```php
use Phpml\FeatureSelection\VarianceThreshold;

$samples = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]];
$transformer = new VarianceThreshold(0.8 * (1 - 0.8));

$transformer->fit($samples);
$transformer->transform($samples);

/*
$samples = [[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]];
*/
```
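
The result above can be checked by hand. The following worked arithmetic applies `Var[X] = p(1 - p)` to each column of `$samples`, where `p` is the fraction of ones in that column; the column numbering is introduced here for illustration.

```latex
\begin{aligned}
t &= 0.8 \cdot (1 - 0.8) = 0.16 \\
\text{column 1: } p &= \tfrac{1}{6}, & \operatorname{Var}[X] &= \tfrac{1}{6} \cdot \tfrac{5}{6} = \tfrac{5}{36} \approx 0.139 \le t \quad \text{(removed)} \\
\text{column 2: } p &= \tfrac{4}{6}, & \operatorname{Var}[X] &= \tfrac{4}{6} \cdot \tfrac{2}{6} = \tfrac{2}{9} \approx 0.222 > t \quad \text{(kept)} \\
\text{column 3: } p &= \tfrac{3}{6}, & \operatorname{Var}[X] &= \tfrac{3}{6} \cdot \tfrac{3}{6} = \tfrac{1}{4} = 0.25 > t \quad \text{(kept)}
\end{aligned}
```

Only the first column is zero in more than 80% of the samples (5 of 6), which is why it is the one that disappears.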

## Pipeline

`VarianceThreshold` implements the `Transformer` interface, so it can be used as part of a pipeline:

```php
use Phpml\FeatureSelection\VarianceThreshold;
use Phpml\Classification\SVC;
use Phpml\FeatureExtraction\TfIdfTransformer;
use Phpml\Pipeline;

$transformers = [
    new TfIdfTransformer(),
    new VarianceThreshold(0.1)
];
$estimator = new SVC();

$pipeline = new Pipeline($transformers, $estimator);
```
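
What makes a transformer like this useful in a pipeline is that the columns chosen during fitting are reused when new samples arrive. The sketch below illustrates that idea in plain PHP without depending on the library. `SketchVarianceFilter` is a hypothetical stand-in written for this example, not the `VarianceThreshold` class itself, and unlike the library's `transform()` it returns a new array instead of modifying its argument by reference.

```php
// Hypothetical stand-in for illustration only; not the library class.
final class SketchVarianceFilter
{
    /** @var float */
    private $threshold;

    /** @var array column index => true for every column that is kept */
    private $keep = [];

    public function __construct(float $threshold = 0.0)
    {
        $this->threshold = $threshold;
    }

    // Decide which columns to keep, based on the fitting samples only.
    public function fit(array $samples): void
    {
        $this->keep = [];
        $n = count($samples);

        foreach (array_keys($samples[0]) as $column) {
            $values = array_column($samples, $column);
            $mean = array_sum($values) / $n;

            $sumOfSquares = 0.0;
            foreach ($values as $value) {
                $sumOfSquares += ($value - $mean) ** 2;
            }

            // Population variance: divide by N, not N - 1.
            if ($sumOfSquares / $n > $this->threshold) {
                $this->keep[$column] = true;
            }
        }
    }

    // Apply the column choice made in fit() to any set of samples.
    public function transform(array $samples): array
    {
        return array_map(function (array $sample): array {
            return array_values(array_intersect_key($sample, $this->keep));
        }, $samples);
    }
}

$train = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]];

$filter = new SketchVarianceFilter(0.8 * (1 - 0.8));
$filter->fit($train);

// [[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]]
$trainReduced = $filter->transform($train);

// The first column is dropped here too, although these samples were never fitted:
// [[1, 0], [0, 0]]
$unseenReduced = $filter->transform([[1, 1, 0], [0, 0, 0]]);
```

Resetting the kept columns at the start of `fit()` is a design choice of this sketch, so that fitting twice does not carry over columns from an earlier data set.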
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -25,6 +25,8 @@ pages:
    - Cross Validation:
      - RandomSplit: machine-learning/cross-validation/random-split.md
      - Stratified Random Split: machine-learning/cross-validation/stratified-random-split.md
    - Feature Selection:
      - VarianceThreshold: machine-learning/feature-selection/variance-threshold.md
    - Preprocessing:
      - Normalization: machine-learning/preprocessing/normalization.md
      - Imputation missing values: machine-learning/preprocessing/imputation-missing-values.md
59 changes: 59 additions & 0 deletions src/FeatureSelection/VarianceThreshold.php
@@ -0,0 +1,59 @@
<?php

declare(strict_types=1);

namespace Phpml\FeatureSelection;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Math\Matrix;
use Phpml\Math\Statistic\Variance;
use Phpml\Transformer;

final class VarianceThreshold implements Transformer
{
    /**
     * @var float
     */
    private $threshold;

    /**
     * @var array
     */
    private $variances = [];

    /**
     * @var array
     */
    private $keepColumns = [];

    public function __construct(float $threshold = 0.0)
    {
        if ($threshold < 0) {
            throw new InvalidArgumentException('Threshold can\'t be lower than zero');
        }

        $this->threshold = $threshold;
        $this->variances = [];
        $this->keepColumns = [];
    }

    public function fit(array $samples): void
    {
        $this->variances = array_map(function (array $column) {
            return Variance::population($column);
        }, Matrix::transposeArray($samples));

        foreach ($this->variances as $column => $variance) {
            if ($variance > $this->threshold) {
                $this->keepColumns[$column] = true;
            }
        }
    }

    public function transform(array &$samples): void
    {
        foreach ($samples as &$sample) {
            $sample = array_values(array_intersect_key($sample, $this->keepColumns));
        }
    }
}
39 changes: 29 additions & 10 deletions src/Math/Statistic/StandardDeviation.php
@@ -9,27 +9,24 @@
 class StandardDeviation
 {
     /**
-     * @param array|float[] $a
-     *
-     * @throws InvalidArgumentException
+     * @param array|float[]|int[] $numbers
      */
-    public static function population(array $a, bool $sample = true): float
+    public static function population(array $numbers, bool $sample = true): float
     {
-        if (empty($a)) {
+        if (empty($numbers)) {
             throw InvalidArgumentException::arrayCantBeEmpty();
         }

-        $n = count($a);
+        $n = count($numbers);

         if ($sample && $n === 1) {
             throw InvalidArgumentException::arraySizeToSmall(2);
         }

-        $mean = Mean::arithmetic($a);
+        $mean = Mean::arithmetic($numbers);
         $carry = 0.0;
-        foreach ($a as $val) {
-            $d = $val - $mean;
-            $carry += $d * $d;
+        foreach ($numbers as $val) {
+            $carry += ($val - $mean) ** 2;
         }

         if ($sample) {
@@ -38,4 +35,26 @@ public static function population(array $a, bool $sample = true): float

         return sqrt((float) ($carry / $n));
     }
+
+    /**
+     * Sum of squares deviations
+     * ∑⟮xᵢ - μ⟯²
+     *
+     * @param array|float[]|int[] $numbers
+     */
+    public static function sumOfSquares(array $numbers): float
+    {
+        if (empty($numbers)) {
+            throw InvalidArgumentException::arrayCantBeEmpty();
+        }
+
+        $mean = Mean::arithmetic($numbers);
+
+        return array_sum(array_map(
+            function ($val) use ($mean) {
+                return ($val - $mean) ** 2;
+            },
+            $numbers
+        ));
+    }
 }
27 changes: 27 additions & 0 deletions src/Math/Statistic/Variance.php
@@ -0,0 +1,27 @@
<?php

declare(strict_types=1);

namespace Phpml\Math\Statistic;

/**
 * In probability theory and statistics, variance is the expectation of the squared deviation of a random variable from its mean.
 * Informally, it measures how far a set of (random) numbers are spread out from their average value
 * https://en.wikipedia.org/wiki/Variance
 */
final class Variance
{
    /**
     * Population variance
     * Use when all possible observations of the system are present.
     * If used with a subset of data (sample variance), it will be a biased variance.
     *
     *      ∑⟮xᵢ - μ⟯²
     * σ² = ----------
     *          N
     */
    public static function population(array $population): float
    {
        return StandardDeviation::sumOfSquares($population) / count($population);
    }
}
39 changes: 39 additions & 0 deletions tests/FeatureSelection/VarianceThresholdTest.php
@@ -0,0 +1,39 @@
<?php

declare(strict_types=1);

namespace Phpml\Tests\FeatureSelection;

use Phpml\Exception\InvalidArgumentException;
use Phpml\FeatureSelection\VarianceThreshold;
use PHPUnit\Framework\TestCase;

final class VarianceThresholdTest extends TestCase
{
    public function testVarianceThreshold(): void
    {
        $samples = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]];
        $transformer = new VarianceThreshold(0.8 * (1 - 0.8)); // 80% of samples - boolean features are Bernoulli random variables
        $transformer->fit($samples);
        $transformer->transform($samples);

        // expecting to remove first column
        self::assertEquals([[0, 1], [1, 0], [0, 0], [1, 1], [1, 0], [1, 1]], $samples);
    }

    public function testVarianceThresholdWithZeroThreshold(): void
    {
        $samples = [[0, 2, 0, 3], [0, 1, 4, 3], [0, 1, 1, 3]];
        $transformer = new VarianceThreshold();
        $transformer->fit($samples);
        $transformer->transform($samples);

        self::assertEquals([[2, 0], [1, 4], [1, 1]], $samples);
    }

    public function testThrowExceptionWhenThresholdBelowZero(): void
    {
        $this->expectException(InvalidArgumentException::class);
        new VarianceThreshold(-0.1);
    }
}
25 changes: 25 additions & 0 deletions tests/Math/Statistic/StandardDeviationTest.php
@@ -37,4 +37,29 @@ public function testThrowExceptionOnToSmallArray(): void
        $this->expectException(InvalidArgumentException::class);
        StandardDeviation::population([1]);
    }

    /**
     * @dataProvider dataProviderForSumOfSquaresDeviations
     */
    public function testSumOfSquares(array $numbers, float $sum): void
    {
        self::assertEquals($sum, StandardDeviation::sumOfSquares($numbers), '', 0.0001);
    }

    public function dataProviderForSumOfSquaresDeviations(): array
    {
        return [
            [[3, 6, 7, 11, 12, 13, 17], 136.8571],
            [[6, 11, 12, 14, 15, 20, 21], 162.8571],
            [[1, 2, 3, 6, 7, 11, 12], 112],
            [[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], 82.5],
            [[34, 253, 754, 2342, 75, 23, 876, 4, 1, -34, -345, 754, -377, 3, 0], 6453975.7333],
        ];
    }

    public function testThrowExceptionOnEmptyArraySumOfSquares(): void
    {
        $this->expectException(InvalidArgumentException::class);
        StandardDeviation::sumOfSquares([]);
    }
}
34 changes: 34 additions & 0 deletions tests/Math/Statistic/VarianceTest.php
@@ -0,0 +1,34 @@
<?php

declare(strict_types=1);

namespace Phpml\Tests\Math\Statistic;

use Phpml\Math\Statistic\Variance;
use PHPUnit\Framework\TestCase;

final class VarianceTest extends TestCase
{
    /**
     * @dataProvider dataProviderForPopulationVariance
     */
    public function testVarianceFromInt(array $numbers, float $variance): void
    {
        self::assertEquals($variance, Variance::population($numbers), '', 0.001);
    }

    public function dataProviderForPopulationVariance()
    {
        return [
            [[0, 0, 0, 0, 0, 1], 0.138],
            [[-11, 0, 10, 20, 30], 208.16],
            [[7, 8, 9, 10, 11, 12, 13], 4.0],
            [[300, 570, 170, 730, 300], 41944],
            [[-4, 2, 7, 8, 3], 18.16],
            [[3, 7, 34, 25, 46, 7754, 3, 6], 6546331.937],
            [[4, 6, 1, 1, 1, 1, 2, 2, 1, 3], 2.56],
            [[-3732, 5, 27, 9248, -174], 18741676.56],
            [[-554, -555, -554, -554, -555, -555, -556], 0.4897],
        ];
    }
}
