Skip to content

Commit

Permalink
implement fit fot TokenCountVectorizer
Browse files Browse the repository at this point in the history
  • Loading branch information
akondas committed Jun 16, 2016
1 parent be74233 commit 424519c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 40 deletions.
58 changes: 22 additions & 36 deletions src/Phpml/FeatureExtraction/TokenCountVectorizer.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ class TokenCountVectorizer implements Transformer
*/
private $vocabulary;

/**
* @var array
*/
private $tokens;

/**
* @var array
*/
Expand All @@ -51,21 +46,19 @@ public function __construct(Tokenizer $tokenizer, float $minDF = 0)
*/
public function fit(array $samples)
{
// TODO: Implement fit() method.
$this->buildVocabulary($samples);
}

/**
* @param array $samples
*/
public function transform(array &$samples)
{
$this->buildVocabulary($samples);

foreach ($samples as $index => $sample) {
$samples[$index] = $this->transformSample($index);
foreach ($samples as &$sample) {
$this->transformSample($sample);
}

$samples = $this->checkDocumentFrequency($samples);
$this->checkDocumentFrequency($samples);
}

/**
Expand All @@ -86,28 +79,27 @@ private function buildVocabulary(array &$samples)
foreach ($tokens as $token) {
$this->addTokenToVocabulary($token);
}
$this->tokens[$index] = $tokens;
}
}

/**
* @param int $index
*
* @return array
* @param string $sample
*/
private function transformSample(int $index)
private function transformSample(string &$sample)
{
$counts = [];
$tokens = $this->tokens[$index];
$tokens = $this->tokenizer->tokenize($sample);

foreach ($tokens as $token) {
$index = $this->getTokenIndex($token);
$this->updateFrequency($token);
if (!isset($counts[$index])) {
$counts[$index] = 0;
}
if(false !== $index) {
$this->updateFrequency($token);
if (!isset($counts[$index])) {
$counts[$index] = 0;
}

++$counts[$index];
++$counts[$index];
}
}

foreach ($this->vocabulary as $index) {
Expand All @@ -116,17 +108,17 @@ private function transformSample(int $index)
}
}

return $counts;
$sample = $counts;
}

/**
* @param string $token
*
* @return int
* @return int|bool
*/
private function getTokenIndex(string $token): int
private function getTokenIndex(string $token)
{
return $this->vocabulary[$token];
return isset($this->vocabulary[$token]) ? $this->vocabulary[$token] : false;
}

/**
Expand Down Expand Up @@ -156,31 +148,25 @@ private function updateFrequency(string $token)
*
* @return array
*/
private function checkDocumentFrequency(array $samples)
private function checkDocumentFrequency(array &$samples)
{
if ($this->minDF > 0) {
$beyondMinimum = $this->getBeyondMinimumIndexes(count($samples));
foreach ($samples as $index => $sample) {
$samples[$index] = $this->resetBeyondMinimum($sample, $beyondMinimum);
foreach ($samples as &$sample) {
$this->resetBeyondMinimum($sample, $beyondMinimum);
}
}

return $samples;
}

/**
* @param array $sample
* @param array $beyondMinimum
*
* @return array
*/
private function resetBeyondMinimum(array $sample, array $beyondMinimum)
private function resetBeyondMinimum(array &$sample, array $beyondMinimum)
{
foreach ($beyondMinimum as $index) {
$sample[$index] = 0;
}

return $sample;
}

/**
Expand Down
14 changes: 10 additions & 4 deletions tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ public function testTokenCountVectorizerWithWhitespaceTokenizer()
];

$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer());
$vectorizer->transform($samples);

$this->assertEquals($tokensCounts, $samples);
$vectorizer->fit($samples);
$this->assertEquals($vocabulary, $vectorizer->getVocabulary());

$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $samples);
}

public function testMinimumDocumentTokenCountFrequency()
Expand Down Expand Up @@ -69,11 +71,14 @@ public function testMinimumDocumentTokenCountFrequency()
];

$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5);
$vectorizer->transform($samples);

$this->assertEquals($tokensCounts, $samples);
$vectorizer->fit($samples);
$this->assertEquals($vocabulary, $vectorizer->getVocabulary());

$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $samples);


// word at least once in all samples
$samples = [
'Lorem ipsum dolor sit amet',
Expand All @@ -88,6 +93,7 @@ public function testMinimumDocumentTokenCountFrequency()
];

$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1);
$vectorizer->fit($samples);
$vectorizer->transform($samples);

$this->assertEquals($tokensCounts, $samples);
Expand Down

0 comments on commit 424519c

Please sign in to comment.