Skip to content

Commit

Permalink
Fix pipeline transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
akondas committed Feb 14, 2018
1 parent 998879b commit b4b190d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/Pipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public function getEstimator(): Estimator
public function train(array $samples, array $targets): void
{
foreach ($this->transformers as $transformer) {
$transformer->fit($samples);
$transformer->fit($samples, $targets);
$transformer->transform($samples);
}

Expand Down
14 changes: 14 additions & 0 deletions tests/PipelineTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
use Phpml\Classification\SVC;
use Phpml\FeatureExtraction\TfIdfTransformer;
use Phpml\FeatureExtraction\TokenCountVectorizer;
use Phpml\FeatureSelection\ScoringFunction\ANOVAFValue;
use Phpml\FeatureSelection\SelectKBest;
use Phpml\ModelManager;
use Phpml\Pipeline;
use Phpml\Preprocessing\Imputer;
Expand Down Expand Up @@ -106,6 +108,18 @@ public function testPipelineTransformers(): void
$this->assertEquals($expected, $predicted);
}

public function testPipelineTransformersWithTargets() : void
{
$samples = [[1, 2, 1], [1, 3, 4], [5, 2, 1], [1, 3, 3], [1, 3, 4], [0, 3, 5]];
$targets = ['a', 'a', 'a', 'b', 'b', 'b'];

$pipeline = new Pipeline([$selector = new SelectKBest(2)], new SVC());
$pipeline->train($samples, $targets);

self::assertEquals([1.47058823, 4.0, 3.0], $selector->scores(), '', 0.00000001);
self::assertEquals(['b'], $pipeline->predict([[1, 3, 5]]));
}

public function testSaveAndRestore(): void
{
$pipeline = new Pipeline([
Expand Down

0 comments on commit b4b190d

Please sign in to comment.