Skip to content

Commit

Permalink
Implement OneHotEncoder (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
akondas authored May 15, 2019
1 parent 3baf152 commit 4590d5c
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 1 deletion.
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,16 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
## [0.9.0] - Unreleased
### Added
- [Preprocessing] Implement LabelEncoder
- [Preprocessing] Implement ColumnFilter
- [Preprocessing] Implement LambdaTransformer
- [Preprocessing] Implement NumberConverter
- [Preprocessing] Implement OneHotEncoder
- [Workflow] Implement FeatureUnion
- [Metric] Add Regression metrics: meanSquaredError, meanSquaredLogarithmicError, meanAbsoluteError, medianAbsoluteError, r2Score, maxError
- [Regression] Implement DecisionTreeRegressor

## [0.8.0] - 2019-03-20
### Added
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ Public datasets are available in a separate repository [php-ai/php-ml-datasets](
* LambdaTransformer
* NumberConverter
* ColumnFilter
* OneHotEncoder
* Feature Extraction
* [Token Count Vectorizer](http://php-ml.readthedocs.io/en/latest/machine-learning/feature-extraction/token-count-vectorizer/)
* NGramTokenizer
Expand Down
66 changes: 66 additions & 0 deletions src/Preprocessing/OneHotEncoder.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
<?php

declare(strict_types=1);

namespace Phpml\Preprocessing;

use Phpml\Exception\InvalidArgumentException;

final class OneHotEncoder implements Preprocessor
{
/**
* @var bool
*/
private $ignoreUnknown;

/**
* @var array
*/
private $categories = [];

public function __construct(bool $ignoreUnknown = false)
{
$this->ignoreUnknown = $ignoreUnknown;
}

public function fit(array $samples, ?array $targets = null): void
{
foreach (array_keys(array_values(current($samples))) as $column) {
$this->fitColumn($column, array_values(array_unique(array_column($samples, $column))));
}
}

public function transform(array &$samples, ?array &$targets = null): void
{
foreach ($samples as &$sample) {
$sample = $this->transformSample(array_values($sample));
}
}

private function fitColumn(int $column, array $values): void
{
$count = count($values);
foreach ($values as $index => $value) {
$map = array_fill(0, $count, 0);
$map[$index] = 1;
$this->categories[$column][$value] = $map;
}
}

private function transformSample(array $sample): array
{
$encoded = [];
foreach ($sample as $column => $feature) {
if (!isset($this->categories[$column][$feature]) && !$this->ignoreUnknown) {
throw new InvalidArgumentException(sprintf('Missing category "%s" for column %s in trained encoder', $feature, $column));
}

$encoded = array_merge(
$encoded,
$this->categories[$column][$feature] ?? array_fill(0, count($this->categories[$column]), 0)
);
}

return $encoded;
}
}
66 changes: 66 additions & 0 deletions tests/Preprocessing/OneHotEncoderTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
<?php

declare(strict_types=1);

namespace Phpml\Tests\Preprocessing;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Preprocessing\OneHotEncoder;
use PHPUnit\Framework\TestCase;

final class OneHotEncoderTest extends TestCase
{
public function testOneHotEncodingWithoutIgnoreUnknown(): void
{
$samples = [
['fish', 'New York', 'regression'],
['dog', 'New York', 'regression'],
['fish', 'Vancouver', 'classification'],
['dog', 'Vancouver', 'regression'],
];

$encoder = new OneHotEncoder();
$encoder->fit($samples);
$encoder->transform($samples);

self::assertEquals([
[1, 0, 1, 0, 1, 0],
[0, 1, 1, 0, 1, 0],
[1, 0, 0, 1, 0, 1],
[0, 1, 0, 1, 1, 0],
], $samples);
}

public function testThrowExceptionWhenUnknownCategory(): void
{
$encoder = new OneHotEncoder();
$encoder->fit([
['fish', 'New York', 'regression'],
['dog', 'New York', 'regression'],
['fish', 'Vancouver', 'classification'],
['dog', 'Vancouver', 'regression'],
]);
$samples = [['fish', 'New York', 'ka boom']];

$this->expectException(InvalidArgumentException::class);

$encoder->transform($samples);
}

public function testIgnoreMissingCategory(): void
{
$encoder = new OneHotEncoder(true);
$encoder->fit([
['fish', 'New York', 'regression'],
['dog', 'New York', 'regression'],
['fish', 'Vancouver', 'classification'],
['dog', 'Vancouver', 'regression'],
]);
$samples = [['ka', 'boom', 'riko']];
$encoder->transform($samples);

self::assertEquals([
[0, 0, 0, 0, 0, 0],
], $samples);
}
}

0 comments on commit 4590d5c

Please sign in to comment.