Skip to content

Commit

Permalink
Partial training base (#78)
Browse files Browse the repository at this point in the history
* Cost values for multiclass OneVsRest uses

* Partial training interface

* Reduce linear classifiers memory usage

* Testing partial training and isolated training

* Partial trainer naming switched to incremental estimator

Other changes according to review's feedback.

* Clean optimization data once optimize is finished

* Abstract resetBinary
  • Loading branch information
dmonllao authored and akondas committed Apr 19, 2017
1 parent c0463ae commit e1854d4
Show file tree
Hide file tree
Showing 11 changed files with 283 additions and 107 deletions.
7 changes: 5 additions & 2 deletions src/Phpml/Classification/Linear/Adaline.php
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@ public function __construct(float $learningRate = 0.001, int $maxIterations = 10
/**
* Adapts the weights with respect to given samples and targets
* by use of gradient descent learning rule
*
* @param array $samples
* @param array $targets
*/
protected function runTraining()
protected function runTraining(array $samples, array $targets)
{
// The cost function is the sum of squares
$callback = function ($weights, $sample, $target) {
Expand All @@ -69,6 +72,6 @@ protected function runTraining()

$isBatch = $this->trainingType == self::BATCH_TRAINING;

return parent::runGradientDescent($callback, $isBatch);
return parent::runGradientDescent($samples, $targets, $callback, $isBatch);
}
}
54 changes: 32 additions & 22 deletions src/Phpml/Classification/Linear/DecisionStump.php
Original file line number Diff line number Diff line change
Expand Up @@ -89,35 +89,33 @@ public function __construct(int $columnIndex = self::AUTO_SELECT)
* @param array $targets
* @throws \Exception
*/
protected function trainBinary(array $samples, array $targets)
protected function trainBinary(array $samples, array $targets, array $labels)
{
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
$this->binaryLabels = array_keys(array_count_values($this->targets));
$this->featureCount = count($this->samples[0]);
$this->binaryLabels = $labels;
$this->featureCount = count($samples[0]);

// If a column index is given, it should be among the existing columns
if ($this->givenColumnIndex > count($this->samples[0]) - 1) {
if ($this->givenColumnIndex > count($samples[0]) - 1) {
$this->givenColumnIndex = self::AUTO_SELECT;
}

// Check the size of the weights given.
// If none given, then assign 1 as a weight to each sample
if ($this->weights) {
$numWeights = count($this->weights);
if ($numWeights != count($this->samples)) {
if ($numWeights != count($samples)) {
throw new \Exception("Number of sample weights does not match with number of samples");
}
} else {
$this->weights = array_fill(0, count($this->samples), 1);
$this->weights = array_fill(0, count($samples), 1);
}

// Determine type of each column as either "continuous" or "nominal"
$this->columnTypes = DecisionTree::getColumnTypes($this->samples);
$this->columnTypes = DecisionTree::getColumnTypes($samples);

// Try to find the best split in the columns of the dataset
// by calculating error rate for each split point in each column
$columns = range(0, count($this->samples[0]) - 1);
$columns = range(0, count($samples[0]) - 1);
if ($this->givenColumnIndex != self::AUTO_SELECT) {
$columns = [$this->givenColumnIndex];
}
Expand All @@ -128,9 +126,9 @@ protected function trainBinary(array $samples, array $targets)
'trainingErrorRate' => 1.0];
foreach ($columns as $col) {
if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
$split = $this->getBestNumericalSplit($col);
$split = $this->getBestNumericalSplit($samples, $targets, $col);
} else {
$split = $this->getBestNominalSplit($col);
$split = $this->getBestNominalSplit($samples, $targets, $col);
}

if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
Expand Down Expand Up @@ -161,13 +159,15 @@ public function setNumericalSplitCount(float $count)
/**
* Determines best split point for the given column
*
* @param array $samples
* @param array $targets
* @param int $col
*
* @return array
*/
protected function getBestNumericalSplit(int $col)
protected function getBestNumericalSplit(array $samples, array $targets, int $col)
{
$values = array_column($this->samples, $col);
$values = array_column($samples, $col);
// Trying all possible points may be accomplished in two general ways:
// 1- Try all values in the $samples array ($values)
// 2- Artificially split the range of values into several parts and try them
Expand All @@ -182,7 +182,7 @@ protected function getBestNumericalSplit(int $col)
// Before trying all possible split points, let's first try
// the average value for the cut point
$threshold = array_sum($values) / (float) count($values);
list($errorRate, $prob) = $this->calculateErrorRate($threshold, $operator, $values);
list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);
if ($split == null || $errorRate < $split['trainingErrorRate']) {
$split = ['value' => $threshold, 'operator' => $operator,
'prob' => $prob, 'column' => $col,
Expand All @@ -192,7 +192,7 @@ protected function getBestNumericalSplit(int $col)
// Try other possible points one by one
for ($step = $minValue; $step <= $maxValue; $step+= $stepSize) {
$threshold = (float)$step;
list($errorRate, $prob) = $this->calculateErrorRate($threshold, $operator, $values);
list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);
if ($errorRate < $split['trainingErrorRate']) {
$split = ['value' => $threshold, 'operator' => $operator,
'prob' => $prob, 'column' => $col,
Expand All @@ -205,21 +205,23 @@ protected function getBestNumericalSplit(int $col)
}

/**
* @param array $samples
* @param array $targets
* @param int $col
*
* @return array
*/
protected function getBestNominalSplit(int $col) : array
protected function getBestNominalSplit(array $samples, array $targets, int $col) : array
{
$values = array_column($this->samples, $col);
$values = array_column($samples, $col);
$valueCounts = array_count_values($values);
$distinctVals= array_keys($valueCounts);

$split = null;

foreach (['=', '!='] as $operator) {
foreach ($distinctVals as $val) {
list($errorRate, $prob) = $this->calculateErrorRate($val, $operator, $values);
list($errorRate, $prob) = $this->calculateErrorRate($targets, $val, $operator, $values);

if ($split == null || $split['trainingErrorRate'] < $errorRate) {
$split = ['value' => $val, 'operator' => $operator,
Expand Down Expand Up @@ -260,13 +262,14 @@ protected function evaluate($leftValue, $operator, $rightValue)
* Calculates the ratio of wrong predictions based on the new threshold
* value given as the parameter
*
* @param array $targets
* @param float $threshold
* @param string $operator
* @param array $values
*
* @return array
*/
protected function calculateErrorRate(float $threshold, string $operator, array $values) : array
protected function calculateErrorRate(array $targets, float $threshold, string $operator, array $values) : array
{
$wrong = 0.0;
$prob = [];
Expand All @@ -280,8 +283,8 @@ protected function calculateErrorRate(float $threshold, string $operator, array
$predicted = $rightLabel;
}

$target = $this->targets[$index];
if (strval($predicted) != strval($this->targets[$index])) {
$target = $targets[$index];
if (strval($predicted) != strval($targets[$index])) {
$wrong += $this->weights[$index];
}

Expand Down Expand Up @@ -340,6 +343,13 @@ protected function predictSampleBinary(array $sample)
return $this->binaryLabels[1];
}

/**
* @return void
*/
protected function resetBinary()
{
}

/**
* @return string
*/
Expand Down
23 changes: 14 additions & 9 deletions src/Phpml/Classification/Linear/LogisticRegression.php
Original file line number Diff line number Diff line change
Expand Up @@ -123,34 +123,39 @@ public function setLambda(float $lambda)
/**
* Adapts the weights with respect to given samples and targets
* by use of selected solver
*
* @param array $samples
* @param array $targets
*/
protected function runTraining()
protected function runTraining(array $samples, array $targets)
{
$callback = $this->getCostFunction();

switch ($this->trainingType) {
case self::BATCH_TRAINING:
return $this->runGradientDescent($callback, true);
return $this->runGradientDescent($samples, $targets, $callback, true);

case self::ONLINE_TRAINING:
return $this->runGradientDescent($callback, false);
return $this->runGradientDescent($samples, $targets, $callback, false);

case self::CONJUGATE_GRAD_TRAINING:
return $this->runConjugateGradient($callback);
return $this->runConjugateGradient($samples, $targets, $callback);
}
}

/**
* Executes Conjugate Gradient method to optimize the
* weights of the LogReg model
*/
protected function runConjugateGradient(\Closure $gradientFunc)
protected function runConjugateGradient(array $samples, array $targets, \Closure $gradientFunc)
{
$optimizer = (new ConjugateGradient($this->featureCount))
->setMaxIterations($this->maxIterations);
if (empty($this->optimizer)) {
$this->optimizer = (new ConjugateGradient($this->featureCount))
->setMaxIterations($this->maxIterations);
}

$this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc);
$this->costValues = $optimizer->getCostValues();
$this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
$this->costValues = $this->optimizer->getCostValues();
}

/**
Expand Down
80 changes: 50 additions & 30 deletions src/Phpml/Classification/Linear/Perceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,17 @@
use Phpml\Helper\Optimizer\GD;
use Phpml\Classification\Classifier;
use Phpml\Preprocessing\Normalizer;
use Phpml\IncrementalEstimator;
use Phpml\Helper\PartiallyTrainable;

class Perceptron implements Classifier
class Perceptron implements Classifier, IncrementalEstimator
{
use Predictable, OneVsRest;

/**
* @var array
*/
protected $samples = [];

/**
* @var array
* @var \Phpml\Helper\Optimizer\Optimizer
*/
protected $targets = [];
protected $optimizer;

/**
* @var array
Expand Down Expand Up @@ -93,32 +90,47 @@ public function __construct(float $learningRate = 0.001, int $maxIterations = 10
$this->maxIterations = $maxIterations;
}

/**
* @param array $samples
* @param array $targets
* @param array $labels
*/
public function partialTrain(array $samples, array $targets, array $labels = array())
{
return $this->trainByLabel($samples, $targets, $labels);
}

/**
* @param array $samples
* @param array $targets
* @param array $labels
*/
public function trainBinary(array $samples, array $targets)
public function trainBinary(array $samples, array $targets, array $labels)
{
$this->labels = array_keys(array_count_values($targets));
if (count($this->labels) > 2) {
throw new \Exception("Perceptron is for binary (two-class) classification only");
}

if ($this->normalizer) {
$this->normalizer->transform($samples);
}

// Set all target values to either -1 or 1
$this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
foreach ($targets as $target) {
$this->targets[] = strval($target) == strval($this->labels[1]) ? 1 : -1;
$this->labels = [1 => $labels[0], -1 => $labels[1]];
foreach ($targets as $key => $target) {
$targets[$key] = strval($target) == strval($this->labels[1]) ? 1 : -1;
}

// Set samples and feature count vars
$this->samples = array_merge($this->samples, $samples);
$this->featureCount = count($this->samples[0]);
$this->featureCount = count($samples[0]);

$this->runTraining($samples, $targets);
}

$this->runTraining();
protected function resetBinary()
{
$this->labels = [];
$this->optimizer = null;
$this->featureCount = 0;
$this->weights = null;
$this->costValues = [];
}

/**
Expand Down Expand Up @@ -151,8 +163,11 @@ public function getCostValues()
/**
* Trains the perceptron model with Stochastic Gradient Descent optimization
* to get the correct set of weights
*
* @param array $samples
* @param array $targets
*/
protected function runTraining()
protected function runTraining(array $samples, array $targets)
{
// The cost function is the sum of squares
$callback = function ($weights, $sample, $target) {
Expand All @@ -165,25 +180,30 @@ protected function runTraining()
return [$error, $gradient];
};

$this->runGradientDescent($callback);
$this->runGradientDescent($samples, $targets, $callback);
}

/**
* Executes Stochastic Gradient Descent algorithm for
* Executes a Gradient Descent algorithm for
* the given cost function
*
* @param array $samples
* @param array $targets
*/
protected function runGradientDescent(\Closure $gradientFunc, bool $isBatch = false)
protected function runGradientDescent(array $samples, array $targets, \Closure $gradientFunc, bool $isBatch = false)
{
$class = $isBatch ? GD::class : StochasticGD::class;

$optimizer = (new $class($this->featureCount))
->setLearningRate($this->learningRate)
->setMaxIterations($this->maxIterations)
->setChangeThreshold(1e-6)
->setEarlyStop($this->enableEarlyStop);
if (empty($this->optimizer)) {
$this->optimizer = (new $class($this->featureCount))
->setLearningRate($this->learningRate)
->setMaxIterations($this->maxIterations)
->setChangeThreshold(1e-6)
->setEarlyStop($this->enableEarlyStop);
}

$this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc);
$this->costValues = $optimizer->getCostValues();
$this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
$this->costValues = $this->optimizer->getCostValues();
}

/**
Expand Down
Loading

0 comments on commit e1854d4

Please sign in to comment.