Skip to content

Commit

Permalink
Fix samples transformation in Pipeline training (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
maximecolin authored and akondas committed May 24, 2017
1 parent de50490 commit 2d3b44f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
17 changes: 5 additions & 12 deletions src/Phpml/Pipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ public function getEstimator()
*/
public function train(array $samples, array $targets)
{
$this->fitTransformers($samples);
$this->transformSamples($samples);
foreach ($this->transformers as $transformer) {
$transformer->fit($samples);
$transformer->transform($samples);
}

$this->estimator->train($samples, $targets);
}

Expand All @@ -84,16 +87,6 @@ public function predict(array $samples)
return $this->estimator->predict($samples);
}

/**
* @param array $samples
*/
private function fitTransformers(array &$samples)
{
foreach ($this->transformers as $transformer) {
$transformer->fit($samples);
}
}

/**
* @param array $samples
*/
Expand Down
39 changes: 39 additions & 0 deletions tests/Phpml/PipelineTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

use Phpml\Classification\SVC;
use Phpml\FeatureExtraction\TfIdfTransformer;
use Phpml\FeatureExtraction\TokenCountVectorizer;
use Phpml\Pipeline;
use Phpml\Preprocessing\Imputer;
use Phpml\Preprocessing\Normalizer;
use Phpml\Preprocessing\Imputer\Strategy\MostFrequentStrategy;
use Phpml\Regression\SVR;
use Phpml\Tokenization\WordTokenizer;
use PHPUnit\Framework\TestCase;

class PipelineTest extends TestCase
Expand Down Expand Up @@ -65,4 +67,41 @@ public function testPipelineWorkflow()

$this->assertEquals(4, $predicted[0]);
}

public function testPipelineTransformers()
{
$transformers = [
new TokenCountVectorizer(new WordTokenizer()),
new TfIdfTransformer()
];

$estimator = new SVC();

$samples = [
'Hello Paul',
'Hello Martin',
'Goodbye Tom',
'Hello John',
'Goodbye Alex',
'Bye Tony',
];

$targets = [
'greetings',
'greetings',
'farewell',
'greetings',
'farewell',
'farewell',
];

$pipeline = new Pipeline($transformers, $estimator);
$pipeline->train($samples, $targets);

$expected = ['greetings', 'farewell'];

$predicted = $pipeline->predict(['Hello Max', 'Goodbye Mark']);

$this->assertEquals($expected, $predicted);
}
}

0 comments on commit 2d3b44f

Please sign in to comment.